Skip to content

Pre-transpose constant MatMul operand #315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions rten-convert/rten_convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
AttributeValue = int | float | str | list[int]


EMITTED_WARNINGS: set[str] = set()


@dataclass
class Metadata:
"""
Expand Down Expand Up @@ -534,6 +537,12 @@ def op_node_from_onnx_operator(
attr_reader.check_attr("input_forget", "int", 0)
attr_reader.check_attr("layout", "int", 0)

case "MatMul":
b = constant_nodes.get(onnx_op.input[-1])
if b and len(b.shape) == 2 and b.shape[-1] > 1:
b.data = np.ascontiguousarray(b.data.transpose())
b.strides = [1, b.shape[0]]

case "MaxPool":
attrs = sg.MaxPoolAttrsT()
kernel_shape = attr_reader.require_attr("kernel_shape", "ints")
Expand Down Expand Up @@ -893,6 +902,12 @@ def build_constant_node(
shape_vec = write_vec(
builder, sg.ConstantNodeStartShapeVector, constant.shape, "u32"
)
if constant.strides is not None:
strides_vec = write_vec(
builder, sg.ConstantNodeStartStridesVector, constant.strides, "u32"
)
else:
strides_vec = None
n_elems = reduce(mul, constant.shape, 1)
assert n_elems == constant.data.size, "constant shape does not match element count"

Expand Down Expand Up @@ -952,6 +967,8 @@ def build_constant_node(
sg.ConstantNodeStart(builder)
sg.ConstantNodeAddShape(builder, shape_vec)
sg.ConstantNodeAddDtype(builder, dtype)
if strides_vec:
sg.ConstantNodeAddStrides(builder, strides_vec)

if inline_data:
sg.ConstantNodeAddDataType(builder, inline_data_type)
Expand Down
4 changes: 3 additions & 1 deletion rten-convert/rten_convert/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
format used by .rten models.
"""

from typing import Any
from typing import Any, Optional

import numpy as np

Expand All @@ -27,12 +27,14 @@ class ConstantNode(Node):
"""

shape: list[int]
strides: Optional[list[int]]
data: np.ndarray

def __init__(self, name: str, shape: list[int], data: np.ndarray):
super().__init__(name)
self.shape = shape
self.data = data
self.strides = None

shape_numel = np.prod(shape)
if shape_numel != data.size:
Expand Down
53 changes: 52 additions & 1 deletion rten-convert/rten_convert/schema_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -6153,8 +6153,35 @@ def DataOffset(self):
return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos)
return None

# ConstantNode
# Return element `j` of the `strides` vector (vtable offset 14), read as a
# uint32. Returns 0 when the field is absent from the table.
def Strides(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
if o != 0:
a = self._tab.Vector(o)
# Elements are addressed at 4-byte intervals, hence the `j * 4` offset.
return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return 0

# ConstantNode
# Return the whole `strides` vector via `GetVectorAsNumpy` with uint32
# elements. Note that an absent field yields the integer 0, not an empty array.
def StridesAsNumpy(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
if o != 0:
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
return 0

# ConstantNode
# Return the number of elements in the `strides` vector, or 0 when the field
# is absent from the table.
def StridesLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
if o != 0:
return self._tab.VectorLen(o)
return 0

# ConstantNode
# Return True when the `strides` field is absent, i.e. the lookup of vtable
# offset 14 yields 0.
def StridesIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
return o == 0

def ConstantNodeStart(builder):
builder.StartObject(5)
builder.StartObject(6)

# Store `shape` (a vector offset) in slot 0 of the table being built.
def ConstantNodeAddShape(builder, shape):
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)
Expand All @@ -6174,6 +6201,12 @@ def ConstantNodeAddDtype(builder, dtype):
# Store `dataOffset` as a uint64 in slot 4 of the table being built. The slot
# default passed to the builder is None.
def ConstantNodeAddDataOffset(builder, dataOffset):
builder.PrependUint64Slot(4, dataOffset, None)

# Store `strides` (a vector offset) in slot 5 of the table being built.
def ConstantNodeAddStrides(builder, strides):
builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(strides), 0)

# Begin the `strides` vector for `numElems` elements and return the result of
# `builder.StartVector`. The two literal 4s are presumably the element size and
# alignment in bytes (uint32 elements) - confirm against the FlatBuffers
# Builder API.
def ConstantNodeStartStridesVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

# Finish the table under construction and return the value produced by
# `builder.EndObject()`.
def ConstantNodeEnd(builder):
return builder.EndObject()

Expand All @@ -6192,6 +6225,7 @@ def __init__(self):
self.data = None # type: Union[None, FloatDataT, Int32DataT, Int8DataT, UInt8DataT]
self.dtype = None # type: Optional[int]
self.dataOffset = None # type: Optional[int]
self.strides = None # type: List[int]

@classmethod
def InitFromBuf(cls, buf, pos):
Expand Down Expand Up @@ -6225,6 +6259,13 @@ def _UnPack(self, constantNode):
self.data = ConstantDataCreator(self.dataType, constantNode.Data())
self.dtype = constantNode.Dtype()
self.dataOffset = constantNode.DataOffset()
if not constantNode.StridesIsNone():
if np is None:
self.strides = []
for i in range(constantNode.StridesLength()):
self.strides.append(constantNode.Strides(i))
else:
self.strides = constantNode.StridesAsNumpy()

# ConstantNodeT
def Pack(self, builder):
Expand All @@ -6238,6 +6279,14 @@ def Pack(self, builder):
shape = builder.EndVector()
if self.data is not None:
data = self.data.Pack(builder)
if self.strides is not None:
if np is not None and type(self.strides) is np.ndarray:
strides = builder.CreateNumpyVector(self.strides)
else:
ConstantNodeStartStridesVector(builder, len(self.strides))
for i in reversed(range(len(self.strides))):
builder.PrependUint32(self.strides[i])
strides = builder.EndVector()
ConstantNodeStart(builder)
if self.shape is not None:
ConstantNodeAddShape(builder, shape)
Expand All @@ -6246,6 +6295,8 @@ def Pack(self, builder):
ConstantNodeAddData(builder, data)
ConstantNodeAddDtype(builder, self.dtype)
ConstantNodeAddDataOffset(builder, self.dataOffset)
if self.strides is not None:
ConstantNodeAddStrides(builder, strides)
constantNode = ConstantNodeEnd(builder)
return constantNode

Expand Down
68 changes: 58 additions & 10 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,9 @@ impl Model {
tensor_data_offset: Option<u64>,
) -> Result<NodeId, ModelLoadError> {
let shape: Vec<usize> = constant.shape().iter().map(|x| x as usize).collect();
let strides: Option<Vec<usize>> = constant
.strides()
.map(|strides| strides.iter().map(|x| x as usize).collect());

if let Some(data_offset) = constant.data_offset() {
// Constant data is stored outside the model buffer, in the same file.
Expand All @@ -595,23 +598,39 @@ impl Model {

let graph_node = match constant.dtype() {
Some(sg::ConstantDataType::Int32) => {
let const_data =
constant_data_from_storage_offset::<i32>(storage, &shape, data_offset)?;
let const_data = constant_data_from_storage_offset::<i32>(
storage,
&shape,
strides.as_deref(),
data_offset,
)?;
graph.add_constant(name, const_data)
}
Some(sg::ConstantDataType::Float32) => {
let const_data =
constant_data_from_storage_offset::<f32>(storage, &shape, data_offset)?;
let const_data = constant_data_from_storage_offset::<f32>(
storage,
&shape,
strides.as_deref(),
data_offset,
)?;
graph.add_constant(name, const_data)
}
Some(sg::ConstantDataType::Int8) => {
let const_data =
constant_data_from_storage_offset::<i8>(storage, &shape, data_offset)?;
let const_data = constant_data_from_storage_offset::<i8>(
storage,
&shape,
strides.as_deref(),
data_offset,
)?;
graph.add_constant(name, const_data)
}
Some(sg::ConstantDataType::UInt8) => {
let const_data =
constant_data_from_storage_offset::<u8>(storage, &shape, data_offset)?;
let const_data = constant_data_from_storage_offset::<u8>(
storage,
&shape,
strides.as_deref(),
data_offset,
)?;
graph.add_constant(name, const_data)
}
_ => {
Expand Down Expand Up @@ -860,6 +879,7 @@ fn cast_le_bytes<T: Pod>(bytes: &[u8]) -> Option<&[T]> {
fn constant_data_from_storage_offset<T: LeBytes + Pod>(
storage: &Arc<ConstantStorage>,
shape: &[usize],
strides: Option<&[usize]>,
offset: usize,
) -> Result<ConstantNodeData<T>, ModelLoadError> {
let n_elements: usize = shape.iter().product();
Expand All @@ -874,14 +894,42 @@ fn constant_data_from_storage_offset<T: LeBytes + Pod>(
if let Some(elements) = cast_le_bytes(bytes) {
let storage =
ArcSlice::new(storage.clone(), elements).expect("storage does not contain data");
let const_data: ConstantNodeData<T> = ArcTensorView::from_data(shape, storage).into();
let const_data: ConstantNodeData<T> = if let Some(strides) = strides {
ArcTensorView::from_data_with_strides(shape, storage, strides)
.map_err(|_| {
ModelLoadError::GraphError(
format!(
"Graph constant strides {:?} incompatible with shape {:?}",
strides, shape
)
.into(),
)
})?
.into()
} else {
ArcTensorView::from_data(shape, storage).into()
};
Ok(const_data)
} else {
let data: Vec<T> = bytes
.chunks(std::mem::size_of::<T>())
.map(|chunk| T::from_le_bytes(chunk.try_into().unwrap()))
.collect();
Ok(Tensor::from_data(shape, data).into())
Ok(if let Some(strides) = strides {
Tensor::from_data_with_strides(shape, data, strides)
.map_err(|_| {
ModelLoadError::GraphError(
format!(
"Graph constant strides {:?} incompatible with shape {:?}",
strides, shape
)
.into(),
)
})?
.into()
} else {
Tensor::from_data(shape, data).into()
})
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/model_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {

sg::ConstantNodeArgs {
shape: Some(shape_vec),
strides: None,
data_type: sg::ConstantData::NONE,
data: None,
data_offset: Some(offset),
Expand All @@ -318,6 +319,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {

sg::ConstantNodeArgs {
shape: Some(shape_vec),
strides: None,
data_type: inline_dtype,
data: Some(data),
data_offset: None,
Expand Down
4 changes: 4 additions & 0 deletions src/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,10 @@ table ConstantNode {
// Offset of tensor data from the start of the tensor data segment in the
// model file. Null if the tensor data is stored inline.
data_offset:uint64 = null;

// Custom strides for each dimension. This enables pre-transposing weights.
// If not specified the strides default to contiguous.
strides:[uint];
}

// Dimension of a ValueNode's shape. This can be either a fixed value or a
Expand Down
30 changes: 30 additions & 0 deletions src/schema_generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10489,6 +10489,7 @@ impl<'a> ConstantNode<'a> {
pub const VT_DATA: flatbuffers::VOffsetT = 8;
pub const VT_DTYPE: flatbuffers::VOffsetT = 10;
pub const VT_DATA_OFFSET: flatbuffers::VOffsetT = 12;
pub const VT_STRIDES: flatbuffers::VOffsetT = 14;

#[inline]
pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self {
Expand All @@ -10503,6 +10504,9 @@ impl<'a> ConstantNode<'a> {
if let Some(x) = args.data_offset {
builder.add_data_offset(x);
}
if let Some(x) = args.strides {
builder.add_strides(x);
}
if let Some(x) = args.data {
builder.add_data(x);
}
Expand Down Expand Up @@ -10572,6 +10576,19 @@ impl<'a> ConstantNode<'a> {
unsafe { self._tab.get::<u64>(ConstantNode::VT_DATA_OFFSET, None) }
}
/// Returns the vector stored in the `strides` field (`VT_STRIDES`) of this
/// table, or `None` when the field is not present.
#[inline]
pub fn strides(&self) -> Option<flatbuffers::Vector<'a, u32>> {
// Safety:
// Created from valid Table for this object
// which contains a valid value in this slot
unsafe {
self._tab
.get::<flatbuffers::ForwardsUOffset<flatbuffers::Vector<'a, u32>>>(
ConstantNode::VT_STRIDES,
None,
)
}
}
#[inline]
#[allow(non_snake_case)]
pub fn data_as_float_data(&self) -> Option<FloatData<'a>> {
if self.data_type() == ConstantData::FloatData {
Expand Down Expand Up @@ -10677,6 +10694,11 @@ impl flatbuffers::Verifiable for ConstantNode<'_> {
)?
.visit_field::<ConstantDataType>("dtype", Self::VT_DTYPE, false)?
.visit_field::<u64>("data_offset", Self::VT_DATA_OFFSET, false)?
.visit_field::<flatbuffers::ForwardsUOffset<flatbuffers::Vector<'_, u32>>>(
"strides",
Self::VT_STRIDES,
false,
)?
.finish();
Ok(())
}
Expand All @@ -10687,6 +10709,7 @@ pub struct ConstantNodeArgs<'a> {
pub data: Option<flatbuffers::WIPOffset<flatbuffers::UnionWIPOffset>>,
pub dtype: Option<ConstantDataType>,
pub data_offset: Option<u64>,
pub strides: Option<flatbuffers::WIPOffset<flatbuffers::Vector<'a, u32>>>,
}
impl<'a> Default for ConstantNodeArgs<'a> {
#[inline]
Expand All @@ -10697,6 +10720,7 @@ impl<'a> Default for ConstantNodeArgs<'a> {
data: None,
dtype: None,
data_offset: None,
strides: None,
}
}
}
Expand Down Expand Up @@ -10735,6 +10759,11 @@ impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> ConstantNodeBuilder<'a, 'b, A>
.push_slot_always::<u64>(ConstantNode::VT_DATA_OFFSET, data_offset);
}
/// Writes the offset of the `strides` vector into the `VT_STRIDES` slot of
/// the table being built.
#[inline]
pub fn add_strides(&mut self, strides: flatbuffers::WIPOffset<flatbuffers::Vector<'b, u32>>) {
self.fbb_
.push_slot_always::<flatbuffers::WIPOffset<_>>(ConstantNode::VT_STRIDES, strides);
}
#[inline]
pub fn new(
_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>,
) -> ConstantNodeBuilder<'a, 'b, A> {
Expand Down Expand Up @@ -10805,6 +10834,7 @@ impl core::fmt::Debug for ConstantNode<'_> {
};
ds.field("dtype", &self.dtype());
ds.field("data_offset", &self.data_offset());
ds.field("strides", &self.strides());
ds.finish()
}
}
Expand Down