@@ -3,7 +3,10 @@ use std::time::Instant;
33use burn:: prelude:: Backend ;
44
55use crate :: burn_model:: BurnTransformerVm ;
6- use crate :: engine:: { ExecutionEngine , ExecutionResult , ExecutionTraceEntry } ;
6+ use crate :: engine:: {
7+ build_execution_result, execution_complete, record_execution_step, ExecutionEngine ,
8+ ExecutionResult , ExecutionTraceEntry ,
9+ } ;
710use crate :: error:: Result ;
811use crate :: instruction:: Instruction ;
912use crate :: memory:: AddressedMemory ;
@@ -26,10 +29,7 @@ impl<B: Backend> BurnExecutionRuntime<B> {
2629 pub fn new ( model : BurnTransformerVm < B > , device : B :: Device , max_steps : usize ) -> Self {
2730 let initial_memory = model. program ( ) . initial_memory ( ) . to_vec ( ) ;
2831 let memory = AddressedMemory :: from_initial ( & initial_memory) ;
29- let state = MachineState {
30- memory : initial_memory,
31- ..MachineState :: new ( model. program ( ) . memory_size ( ) )
32- } ;
32+ let state = MachineState :: with_memory ( initial_memory) ;
3333
3434 Self {
3535 model,
@@ -44,7 +44,7 @@ impl<B: Backend> BurnExecutionRuntime<B> {
4444 }
4545
4646 pub fn step ( & mut self ) -> Result < & MachineState > {
47- if self . state . halted || self . step_count >= self . max_steps {
47+ if execution_complete ( & self . state , self . step_count , self . max_steps ) {
4848 return Ok ( & self . state ) ;
4949 }
5050
@@ -58,36 +58,28 @@ impl<B: Backend> BurnExecutionRuntime<B> {
5858 ) ?;
5959 self . state = next;
6060 self . step_count += 1 ;
61- self . trace . push ( self . state . clone ( ) ) ;
62- self . events . push ( ExecutionTraceEntry {
63- step : self . step_count ,
64- layer_idx : Some ( dispatch. layer_idx ) ,
65- instruction : dispatch. instruction ,
66- state_before : before,
67- state_after : self . state . clone ( ) ,
68- } ) ;
61+ record_execution_step (
62+ & mut self . trace ,
63+ & mut self . events ,
64+ self . step_count ,
65+ Some ( dispatch. layer_idx ) ,
66+ dispatch. instruction ,
67+ before,
68+ & self . state ,
69+ ) ;
6970 Ok ( & self . state )
7071 }
7172
7273 pub fn run ( & mut self ) -> Result < ExecutionResult > {
7374 let start = Instant :: now ( ) ;
74- while self . step_count < self . max_steps && ! self . state . halted {
75+ while ! execution_complete ( & self . state , self . step_count , self . max_steps ) {
7576 self . step ( ) ?;
7677 }
77-
78- let elapsed = start. elapsed ( ) ;
79- let elapsed_secs = elapsed. as_secs_f64 ( ) ;
80- Ok ( ExecutionResult {
81- final_state : self . state . clone ( ) ,
82- steps : self . step_count ,
83- halted : self . state . halted ,
84- elapsed,
85- tokens_per_sec : if elapsed_secs > 0.0 {
86- self . step_count as f64 / elapsed_secs
87- } else {
88- 0.0
89- } ,
90- } )
78+ Ok ( build_execution_result (
79+ & self . state ,
80+ self . step_count ,
81+ start. elapsed ( ) ,
82+ ) )
9183 }
9284
9385 pub fn trace ( & self ) -> & [ MachineState ] {
@@ -119,7 +111,7 @@ impl<B: Backend> BurnExecutionRuntime<B> {
119111 }
120112
121113 pub fn next_dispatch ( & self ) -> Result < Option < DispatchInfo > > {
122- if self . state . halted || self . step_count >= self . max_steps {
114+ if execution_complete ( & self . state , self . step_count , self . max_steps ) {
123115 return Ok ( None ) ;
124116 }
125117 self . model . dispatch_info ( self . state . pc ) . map ( Some )
0 commit comments