11use std:: any:: Any ;
2+ use std:: collections:: HashMap ;
23use std:: error:: Error ;
34use std:: fmt:: { Display , Formatter } ;
45use std:: sync:: Arc ;
56
7+ use rten_shape_inference:: SymExpr ;
68use rten_tensor:: Tensor ;
79use rustc_hash:: FxHashSet ;
810use smallvec:: SmallVec ;
911
10- use crate :: Value ;
1112use crate :: env:: str_as_bool;
1213use crate :: graph:: {
1314 CaptureEnv , Constant , ConstantNode , ConstantNodeData , Graph , Node , NodeId , OperatorNode ,
1415 PlanOptions , RunError ,
1516} ;
1617use crate :: infer_shapes:: { InferError , InferShapeOptions , infer_shapes} ;
1718use crate :: operator:: Operator ;
18- use crate :: ops:: Identity ;
19+ use crate :: ops:: { ComputeShape , Identity , SymExprKind , SymbolInfo } ;
20+ use crate :: { Dimension , Value } ;
1921
2022mod diagnostics;
2123mod fusions;
@@ -24,12 +26,11 @@ mod pattern_matcher;
2426use diagnostics:: { DiagnosticLevel , Diagnostics } ;
2527
2628use fusions:: {
27- AddSoftmaxFusion , ApproxGeluFusion , CastElimination , ComputeShapeFusion , Fusion , FusionError ,
28- FusionVisitor , GeluFusion , GroupedQueryAttentionMatMulFusion , IdentityFusion ,
29- LayerNormalizationFusion , MatMulAddFusion , MatMulIntegerToFloatFusion , MatMulScaleFusion ,
30- PatternFusion , ReciprocalFusion , ReduceMeanAxesFusion , RepeatInterleaveFusion ,
31- RmsNormalizationFusion , SafeSoftmaxFusion , ShapeSliceToConstant , SiluFusion , SwishFusion ,
32- TransposeFusion ,
29+ AddSoftmaxFusion , ApproxGeluFusion , CastElimination , Fusion , FusionError , FusionVisitor ,
30+ GeluFusion , GroupedQueryAttentionMatMulFusion , IdentityFusion , LayerNormalizationFusion ,
31+ MatMulAddFusion , MatMulIntegerToFloatFusion , MatMulScaleFusion , PatternFusion ,
32+ ReciprocalFusion , ReduceMeanAxesFusion , RepeatInterleaveFusion , RmsNormalizationFusion ,
33+ SafeSoftmaxFusion , ShapeSliceToConstant , SiluFusion , SwishFusion , TransposeFusion ,
3334} ;
3435
3536/// Errors that occur while applying graph optimizations.
@@ -424,30 +425,64 @@ impl GraphOptimizer {
424425 if let Some ( infer_opts) = opts. infer_shapes {
425426 let infer_result = infer_shapes ( & graph_mut. graph , infer_opts)
426427 . map_err ( OptimizeError :: InferShapesError ) ?;
427- let const_ids: Vec < Option < NodeId > > = infer_result
428+
429+ let sym_map = symbol_map ( & graph_mut. graph ) ;
430+
431+ // IDs of constants and value nodes that replace value IDs in the
432+ // input.
433+ //
434+ // Where shape inference infers that a value node has a fixed value,
435+ // it can be replaced with a constant. Where it infers the value
436+ // can be produced by evaluating a symbolic expression, replace the
437+ // value with the output of a `ComputeShape` node.
438+ let replacement_ids: Vec < Option < NodeId > > = infer_result
428439 . values
429440 . iter ( )
430441 . map ( |expr| {
431- let constant = expr. to_constant ( ) ?;
432- let tensor = match constant {
433- rten_shape_inference:: Constant :: Scalar ( x) => Tensor :: from ( x) ,
434- rten_shape_inference:: Constant :: Vector ( vec) => Tensor :: from ( vec) ,
435- } ;
436- let const_id = graph_mut. add_constant ( None , tensor. into_arc ( ) ) ;
437- Some ( const_id)
442+ if let Some ( constant) = expr. to_constant ( ) {
443+ let tensor = match constant {
444+ rten_shape_inference:: Constant :: Scalar ( x) => Tensor :: from ( x) ,
445+ rten_shape_inference:: Constant :: Vector ( vec) => Tensor :: from ( vec) ,
446+ } ;
447+ let const_id = graph_mut. add_constant ( None , tensor. into_arc ( ) ) ;
448+ Some ( const_id)
449+ } else if let Some ( values) = expr. values ( ) {
450+ let ( symbols, input_ids) = compute_shape_inputs ( values, & sym_map) ;
451+ let op = ComputeShape {
452+ symbols,
453+ elements : if let Some ( expr) = expr. as_vector ( ) {
454+ SymExprKind :: Vector ( expr. to_vec ( ) )
455+ } else {
456+ SymExprKind :: Scalar ( values[ 0 ] . clone ( ) )
457+ } ,
458+ } ;
459+ let input_ids: Vec < _ > = input_ids. into_iter ( ) . map ( Some ) . collect ( ) ;
460+ let output_id = graph_mut. graph . add_value ( None , None , None ) ;
461+ graph_mut. add_operator ( None , Arc :: new ( op) , & input_ids, & [ Some ( output_id) ] ) ;
462+ Some ( output_id)
463+ } else {
464+ None
465+ }
438466 } )
439467 . collect ( ) ;
440468
469+ let mut removed_nodes = Vec :: new ( ) ;
441470 for ( value_id, shape_index) in & infer_result. shapes {
442- if let Some ( const_id) = const_ids[ * shape_index] {
443- graph_mut. replace_value ( * value_id, const_id) ;
471+ if let Some ( new_value_id) = replacement_ids[ * shape_index] {
472+ graph_mut. replace_value ( * value_id, new_value_id) ;
473+ removed_nodes. push ( * value_id) ;
474+ if let Some ( ( src_op_id, _src_op) ) = graph_mut. graph . get_source_node ( * value_id) {
475+ removed_nodes. push ( src_op_id) ;
476+ }
444477 } else if let Some ( dims) = infer_result. dims ( * value_id) {
445478 graph_mut. graph . update_value_shape ( * value_id, dims) ;
446479 }
447480 }
448481 for ( value_id, value_type) in infer_result. types {
449482 graph_mut. graph . update_value_type ( value_id, value_type) ;
450483 }
484+
485+ graph_mut. graph . remove_nodes ( & removed_nodes) ;
451486 }
452487
453488 // "Early" fusions. These are fusions which can benefit constant
@@ -464,13 +499,6 @@ impl GraphOptimizer {
464499 early_fusions. push ( CastElimination { } ) ;
465500 early_fusions. push ( IdentityFusion { } ) ;
466501
467- // Fusion which replaces Shape nodes using shape inference metadata.
468- //
469- // This can free up the source of the Shape's input to be included in
470- // other fusions. If all dimensions have static sizes, constant prop
471- // will remove the ComputeShape node and downstream nodes.
472- early_fusions. push ( ComputeShapeFusion { } ) ;
473-
474502 self . apply_fusions ( & mut graph_mut, early_fusions. visitors ( ) , & diag) ?;
475503
476504 // Constant propagation.
@@ -736,5 +764,73 @@ impl FusionList {
736764 }
737765}
738766
767+ /// Create a map of dimension name to (value_id, dim), for use with
768+ /// [`compute_shape_inputs`].
769+ fn symbol_map ( graph : & Graph ) -> HashMap < String , ( NodeId , u32 ) > {
770+ let mut map = HashMap :: new ( ) ;
771+
772+ for id in graph. input_ids ( ) {
773+ let Some ( node) = graph. get_node ( * id) else {
774+ continue ;
775+ } ;
776+ let Some ( shape) = node. shape ( ) else {
777+ continue ;
778+ } ;
779+
780+ for ( dim_idx, dim) in shape. iter ( ) . enumerate ( ) {
781+ match dim {
782+ Dimension :: Symbolic ( name) => {
783+ if !map. contains_key ( name) {
784+ map. insert ( name. to_string ( ) , ( * id, dim_idx as u32 ) ) ;
785+ }
786+ }
787+ Dimension :: Fixed ( _) => { }
788+ }
789+ }
790+ }
791+
792+ map
793+ }
794+
795+ /// Generate the input ID list and symbol_name => (input_id, axis) mappings for
796+ /// a [`ComputeShape`] operator.
797+ fn compute_shape_inputs (
798+ elements : & [ SymExpr ] ,
799+ syms : & HashMap < String , ( NodeId , u32 ) > ,
800+ ) -> ( Vec < SymbolInfo > , Vec < NodeId > ) {
801+ let vars = elements
802+ . iter ( )
803+ . flat_map ( |expr| expr. iter ( ) )
804+ . filter_map ( |node| match node {
805+ SymExpr :: Var ( sym) => Some ( sym. name . as_ref ( ) ) ,
806+ _ => None ,
807+ } ) ;
808+
809+ let mut input_ids = Vec :: new ( ) ;
810+ let mut symbols = Vec :: < SymbolInfo > :: new ( ) ;
811+ for var in vars {
812+ if symbols. iter ( ) . any ( |s| s. name == var) {
813+ continue ;
814+ }
815+ let Some ( ( input_id, axis) ) = syms. get ( var) else {
816+ continue ;
817+ } ;
818+ let input_id = if let Some ( idx) = input_ids. iter ( ) . position ( |id| id == input_id) {
819+ idx
820+ } else {
821+ let idx = input_ids. len ( ) ;
822+ input_ids. push ( * input_id) ;
823+ idx
824+ } ;
825+ symbols. push ( SymbolInfo {
826+ name : var. to_string ( ) ,
827+ input : input_id as u32 ,
828+ axis : * axis,
829+ } ) ;
830+ }
831+
832+ ( symbols, input_ids)
833+ }
834+
739835#[ cfg( test) ]
740836mod tests;
0 commit comments