Skip to content

Commit fdeceaa

Browse files
committed
Replace shape-manipulating subgraphs with ComputeShape operator
Use symbolic tensor values produced by shape inference to replace subgraphs in models with a `ComputeShape` operator. In the Llama 3 model, this reduces the total number of operators run on each inference pass from 1740 to 1192 (-31%). This reduces interpreter overhead, although the impact is usually small because shape-manipulating subgraphs operate on small values. More importantly, it removes more potentially fusion-blocking value consumers from the graph. This generalizes the idea of `ComputeShapeFusion`.
1 parent a155a91 commit fdeceaa

File tree

5 files changed

+167
-219
lines changed

5 files changed

+167
-219
lines changed

src/ops/compute_shape.rs

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,28 @@ pub struct SymbolInfo {
2323
pub axis: u32,
2424
}
2525

26-
/// Compute a tensor shape from a combination of static values and the dynamic
27-
/// shapes of inputs.
26+
#[derive(Debug)]
27+
pub enum SymExprKind {
28+
Scalar(SymExpr),
29+
Vector(Vec<SymExpr>),
30+
}
31+
32+
/// Produce a tensor by evaluating symbolic expressions.
33+
///
34+
/// The symbolic expressions can include named symbols whose values are obtained
35+
/// from the dimension sizes of input tensors.
36+
///
37+
/// This operator can replace subgraphs in ONNX models which extract and
38+
/// manipulate tensor shapes.
2839
#[derive(Debug)]
2940
pub struct ComputeShape {
3041
/// Specifies how to map dimension sizes of inputs to symbols used by the
3142
/// `elements` field.
3243
pub symbols: Vec<SymbolInfo>,
3344

34-
/// Specifies the symbolic expression to evaluate for each position in the
35-
/// output vector.
36-
pub shape: Vec<SymExpr>,
45+
/// Specifies the rank of the output tensor and the symbolic expression to
46+
/// evaluate for each element.
47+
pub elements: SymExprKind,
3748
}
3849

3950
impl Operator for ComputeShape {
@@ -62,16 +73,26 @@ impl Operator for ComputeShape {
6273
.collect::<Result<Vec<_>, _>>()?;
6374
let symbols = SymbolMap::new(&symbols);
6475

65-
let output = self
66-
.shape
67-
.iter()
68-
.map(|expr| {
69-
expr.eval(&symbols)
70-
.map_err(|_| OpError::InvalidValue("Failed to evaluate symbolic shape"))
71-
})
72-
.collect::<Result<Vec<i32>, _>>()?;
76+
let output = match &self.elements {
77+
SymExprKind::Scalar(expr) => {
78+
let item = expr
79+
.eval(&symbols)
80+
.map_err(|_| OpError::InvalidValue("Failed to evaluate symbolic shape"))?;
81+
Tensor::from(item)
82+
}
83+
SymExprKind::Vector(shape) => {
84+
let output = shape
85+
.iter()
86+
.map(|expr| {
87+
expr.eval(&symbols)
88+
.map_err(|_| OpError::InvalidValue("Failed to evaluate symbolic shape"))
89+
})
90+
.collect::<Result<Vec<i32>, _>>()?;
91+
Tensor::from(output)
92+
}
93+
};
7394

74-
Tensor::from(output).into_op_result()
95+
output.into_op_result()
7596
}
7697

7798
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
@@ -108,7 +129,7 @@ mod tests {
108129
},
109130
]
110131
.to_vec(),
111-
shape: [
132+
elements: super::SymExprKind::Vector(vec![
112133
SymExpr::Value(3),
113134
SymExpr::Var(
114135
Symbol {
@@ -125,8 +146,7 @@ mod tests {
125146
}
126147
.into(),
127148
),
128-
]
129-
.into(),
149+
]),
130150
};
131151
let result: NdTensor<i32, 1> = op.run_simple((input_a.view(), input_b.view())).unwrap();
132152

@@ -140,7 +160,7 @@ mod tests {
140160
axis: 0,
141161
}]
142162
.into(),
143-
shape: Vec::new(),
163+
elements: super::SymExprKind::Vector(Vec::new()),
144164
};
145165
let result: Result<NdTensor<i32, 1>, _> = op.run_simple(input_a.view());
146166
assert_eq!(result.err().unwrap(), OpError::MissingInputs);
@@ -153,7 +173,7 @@ mod tests {
153173
axis: 3,
154174
}]
155175
.into(),
156-
shape: Vec::new(),
176+
elements: super::SymExprKind::Vector(Vec::new()),
157177
};
158178
let result: Result<NdTensor<i32, 1>, _> = op.run_simple(input_a.view());
159179
assert_eq!(

src/ops/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ pub(crate) use {
7272
Add, And, Div, Equal, Greater, GreaterOrEqual, Less, LessOrEqual, Mod, Mul, Or, Pow, Sub,
7373
Where, Xor,
7474
},
75-
compute_shape::{ComputeShape, SymbolInfo},
75+
compute_shape::{ComputeShape, SymExprKind, SymbolInfo},
7676
concat::{Concat, Tile},
7777
control_flow::{If, Loop},
7878
conv::{Conv, ConvInteger},

src/optimize.rs

Lines changed: 121 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
use std::any::Any;
2+
use std::collections::HashMap;
23
use std::error::Error;
34
use std::fmt::{Display, Formatter};
45
use std::sync::Arc;
56

7+
use rten_shape_inference::SymExpr;
68
use rten_tensor::Tensor;
79
use rustc_hash::FxHashSet;
810
use smallvec::SmallVec;
911

10-
use crate::Value;
1112
use crate::env::str_as_bool;
1213
use crate::graph::{
1314
CaptureEnv, Constant, ConstantNode, ConstantNodeData, Graph, Node, NodeId, OperatorNode,
1415
PlanOptions, RunError,
1516
};
1617
use crate::infer_shapes::{InferError, InferShapeOptions, infer_shapes};
1718
use crate::operator::Operator;
18-
use crate::ops::Identity;
19+
use crate::ops::{ComputeShape, Identity, SymExprKind, SymbolInfo};
20+
use crate::{Dimension, Value};
1921

2022
mod diagnostics;
2123
mod fusions;
@@ -24,12 +26,11 @@ mod pattern_matcher;
2426
use diagnostics::{DiagnosticLevel, Diagnostics};
2527

2628
use fusions::{
27-
AddSoftmaxFusion, ApproxGeluFusion, CastElimination, ComputeShapeFusion, Fusion, FusionError,
28-
FusionVisitor, GeluFusion, GroupedQueryAttentionMatMulFusion, IdentityFusion,
29-
LayerNormalizationFusion, MatMulAddFusion, MatMulIntegerToFloatFusion, MatMulScaleFusion,
30-
PatternFusion, ReciprocalFusion, ReduceMeanAxesFusion, RepeatInterleaveFusion,
31-
RmsNormalizationFusion, SafeSoftmaxFusion, ShapeSliceToConstant, SiluFusion, SwishFusion,
32-
TransposeFusion,
29+
AddSoftmaxFusion, ApproxGeluFusion, CastElimination, Fusion, FusionError, FusionVisitor,
30+
GeluFusion, GroupedQueryAttentionMatMulFusion, IdentityFusion, LayerNormalizationFusion,
31+
MatMulAddFusion, MatMulIntegerToFloatFusion, MatMulScaleFusion, PatternFusion,
32+
ReciprocalFusion, ReduceMeanAxesFusion, RepeatInterleaveFusion, RmsNormalizationFusion,
33+
SafeSoftmaxFusion, ShapeSliceToConstant, SiluFusion, SwishFusion, TransposeFusion,
3334
};
3435

3536
/// Errors that occur while applying graph optimizations.
@@ -424,30 +425,64 @@ impl GraphOptimizer {
424425
if let Some(infer_opts) = opts.infer_shapes {
425426
let infer_result = infer_shapes(&graph_mut.graph, infer_opts)
426427
.map_err(OptimizeError::InferShapesError)?;
427-
let const_ids: Vec<Option<NodeId>> = infer_result
428+
429+
let sym_map = symbol_map(&graph_mut.graph);
430+
431+
// IDs of constants and value nodes that replace value IDs in the
432+
// input.
433+
//
434+
// Where shape inference infers that a value node has a fixed value,
435+
// it can be replaced with a constant. Where it infers the value
436+
// can be produced by evaluating a symbolic expression, replace the
437+
// value with the output of a `ComputeShape` node.
438+
let replacement_ids: Vec<Option<NodeId>> = infer_result
428439
.values
429440
.iter()
430441
.map(|expr| {
431-
let constant = expr.to_constant()?;
432-
let tensor = match constant {
433-
rten_shape_inference::Constant::Scalar(x) => Tensor::from(x),
434-
rten_shape_inference::Constant::Vector(vec) => Tensor::from(vec),
435-
};
436-
let const_id = graph_mut.add_constant(None, tensor.into_arc());
437-
Some(const_id)
442+
if let Some(constant) = expr.to_constant() {
443+
let tensor = match constant {
444+
rten_shape_inference::Constant::Scalar(x) => Tensor::from(x),
445+
rten_shape_inference::Constant::Vector(vec) => Tensor::from(vec),
446+
};
447+
let const_id = graph_mut.add_constant(None, tensor.into_arc());
448+
Some(const_id)
449+
} else if let Some(values) = expr.values() {
450+
let (symbols, input_ids) = compute_shape_inputs(values, &sym_map);
451+
let op = ComputeShape {
452+
symbols,
453+
elements: if let Some(expr) = expr.as_vector() {
454+
SymExprKind::Vector(expr.to_vec())
455+
} else {
456+
SymExprKind::Scalar(values[0].clone())
457+
},
458+
};
459+
let input_ids: Vec<_> = input_ids.into_iter().map(Some).collect();
460+
let output_id = graph_mut.graph.add_value(None, None, None);
461+
graph_mut.add_operator(None, Arc::new(op), &input_ids, &[Some(output_id)]);
462+
Some(output_id)
463+
} else {
464+
None
465+
}
438466
})
439467
.collect();
440468

469+
let mut removed_nodes = Vec::new();
441470
for (value_id, shape_index) in &infer_result.shapes {
442-
if let Some(const_id) = const_ids[*shape_index] {
443-
graph_mut.replace_value(*value_id, const_id);
471+
if let Some(new_value_id) = replacement_ids[*shape_index] {
472+
graph_mut.replace_value(*value_id, new_value_id);
473+
removed_nodes.push(*value_id);
474+
if let Some((src_op_id, _src_op)) = graph_mut.graph.get_source_node(*value_id) {
475+
removed_nodes.push(src_op_id);
476+
}
444477
} else if let Some(dims) = infer_result.dims(*value_id) {
445478
graph_mut.graph.update_value_shape(*value_id, dims);
446479
}
447480
}
448481
for (value_id, value_type) in infer_result.types {
449482
graph_mut.graph.update_value_type(value_id, value_type);
450483
}
484+
485+
graph_mut.graph.remove_nodes(&removed_nodes);
451486
}
452487

453488
// "Early" fusions. These are fusions which can benefit constant
@@ -464,13 +499,6 @@ impl GraphOptimizer {
464499
early_fusions.push(CastElimination {});
465500
early_fusions.push(IdentityFusion {});
466501

467-
// Fusion which replaces Shape nodes using shape inference metadata.
468-
//
469-
// This can free up the source of the Shape's input to be included in
470-
// other fusions. If all dimensions have static sizes, constant prop
471-
// will remove the ComputeShape node and downstream nodes.
472-
early_fusions.push(ComputeShapeFusion {});
473-
474502
self.apply_fusions(&mut graph_mut, early_fusions.visitors(), &diag)?;
475503

476504
// Constant propagation.
@@ -736,5 +764,73 @@ impl FusionList {
736764
}
737765
}
738766

767+
/// Create a map of dimension name to (value_id, dim), for use with
768+
/// [`compute_shape_inputs`].
769+
fn symbol_map(graph: &Graph) -> HashMap<String, (NodeId, u32)> {
770+
let mut map = HashMap::new();
771+
772+
for id in graph.input_ids() {
773+
let Some(node) = graph.get_node(*id) else {
774+
continue;
775+
};
776+
let Some(shape) = node.shape() else {
777+
continue;
778+
};
779+
780+
for (dim_idx, dim) in shape.iter().enumerate() {
781+
match dim {
782+
Dimension::Symbolic(name) => {
783+
if !map.contains_key(name) {
784+
map.insert(name.to_string(), (*id, dim_idx as u32));
785+
}
786+
}
787+
Dimension::Fixed(_) => {}
788+
}
789+
}
790+
}
791+
792+
map
793+
}
794+
795+
/// Generate the input ID list and symbol_name => (input_id, axis) mappings for
796+
/// a [`ComputeShape`] operator.
797+
fn compute_shape_inputs(
798+
elements: &[SymExpr],
799+
syms: &HashMap<String, (NodeId, u32)>,
800+
) -> (Vec<SymbolInfo>, Vec<NodeId>) {
801+
let vars = elements
802+
.iter()
803+
.flat_map(|expr| expr.iter())
804+
.filter_map(|node| match node {
805+
SymExpr::Var(sym) => Some(sym.name.as_ref()),
806+
_ => None,
807+
});
808+
809+
let mut input_ids = Vec::new();
810+
let mut symbols = Vec::<SymbolInfo>::new();
811+
for var in vars {
812+
if symbols.iter().any(|s| s.name == var) {
813+
continue;
814+
}
815+
let Some((input_id, axis)) = syms.get(var) else {
816+
continue;
817+
};
818+
let input_id = if let Some(idx) = input_ids.iter().position(|id| id == input_id) {
819+
idx
820+
} else {
821+
let idx = input_ids.len();
822+
input_ids.push(*input_id);
823+
idx
824+
};
825+
symbols.push(SymbolInfo {
826+
name: var.to_string(),
827+
input: input_id as u32,
828+
axis: *axis,
829+
});
830+
}
831+
832+
(symbols, input_ids)
833+
}
834+
739835
#[cfg(test)]
740836
mod tests;

0 commit comments

Comments
 (0)