Skip to content

Commit e4e1532

Browse files
committed
Refactor execution runtime cleanup
1 parent bd0833c commit e4e1532

File tree

8 files changed

+312
-290
lines changed

8 files changed

+312
-290
lines changed

src/burn_runtime.rs

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ use std::time::Instant;
33
use burn::prelude::Backend;
44

55
use crate::burn_model::BurnTransformerVm;
6-
use crate::engine::{ExecutionEngine, ExecutionResult, ExecutionTraceEntry};
6+
use crate::engine::{
7+
build_execution_result, execution_complete, record_execution_step, ExecutionEngine,
8+
ExecutionResult, ExecutionTraceEntry,
9+
};
710
use crate::error::Result;
811
use crate::instruction::Instruction;
912
use crate::memory::AddressedMemory;
@@ -26,10 +29,7 @@ impl<B: Backend> BurnExecutionRuntime<B> {
2629
pub fn new(model: BurnTransformerVm<B>, device: B::Device, max_steps: usize) -> Self {
2730
let initial_memory = model.program().initial_memory().to_vec();
2831
let memory = AddressedMemory::from_initial(&initial_memory);
29-
let state = MachineState {
30-
memory: initial_memory,
31-
..MachineState::new(model.program().memory_size())
32-
};
32+
let state = MachineState::with_memory(initial_memory);
3333

3434
Self {
3535
model,
@@ -44,7 +44,7 @@ impl<B: Backend> BurnExecutionRuntime<B> {
4444
}
4545

4646
pub fn step(&mut self) -> Result<&MachineState> {
47-
if self.state.halted || self.step_count >= self.max_steps {
47+
if execution_complete(&self.state, self.step_count, self.max_steps) {
4848
return Ok(&self.state);
4949
}
5050

@@ -58,36 +58,28 @@ impl<B: Backend> BurnExecutionRuntime<B> {
5858
)?;
5959
self.state = next;
6060
self.step_count += 1;
61-
self.trace.push(self.state.clone());
62-
self.events.push(ExecutionTraceEntry {
63-
step: self.step_count,
64-
layer_idx: Some(dispatch.layer_idx),
65-
instruction: dispatch.instruction,
66-
state_before: before,
67-
state_after: self.state.clone(),
68-
});
61+
record_execution_step(
62+
&mut self.trace,
63+
&mut self.events,
64+
self.step_count,
65+
Some(dispatch.layer_idx),
66+
dispatch.instruction,
67+
before,
68+
&self.state,
69+
);
6970
Ok(&self.state)
7071
}
7172

7273
pub fn run(&mut self) -> Result<ExecutionResult> {
7374
let start = Instant::now();
74-
while self.step_count < self.max_steps && !self.state.halted {
75+
while !execution_complete(&self.state, self.step_count, self.max_steps) {
7576
self.step()?;
7677
}
77-
78-
let elapsed = start.elapsed();
79-
let elapsed_secs = elapsed.as_secs_f64();
80-
Ok(ExecutionResult {
81-
final_state: self.state.clone(),
82-
steps: self.step_count,
83-
halted: self.state.halted,
84-
elapsed,
85-
tokens_per_sec: if elapsed_secs > 0.0 {
86-
self.step_count as f64 / elapsed_secs
87-
} else {
88-
0.0
89-
},
90-
})
78+
Ok(build_execution_result(
79+
&self.state,
80+
self.step_count,
81+
start.elapsed(),
82+
))
9183
}
9284

9385
pub fn trace(&self) -> &[MachineState] {
@@ -119,7 +111,7 @@ impl<B: Backend> BurnExecutionRuntime<B> {
119111
}
120112

121113
pub fn next_dispatch(&self) -> Result<Option<DispatchInfo>> {
122-
if self.state.halted || self.step_count >= self.max_steps {
114+
if execution_complete(&self.state, self.step_count, self.max_steps) {
123115
return Ok(None);
124116
}
125117
self.model.dispatch_info(self.state.pc).map(Some)

src/engine.rs

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,57 @@ pub trait ExecutionEngine {
4545
fn next_instruction(&self) -> Result<Option<Instruction>>;
4646

4747
fn is_halted(&self) -> bool {
48-
self.state().halted || self.step_count() >= self.max_steps()
48+
execution_complete(self.state(), self.step_count(), self.max_steps())
49+
}
50+
}
51+
52+
pub(crate) fn execution_complete(
53+
state: &MachineState,
54+
step_count: usize,
55+
max_steps: usize,
56+
) -> bool {
57+
state.halted || step_count >= max_steps
58+
}
59+
60+
pub(crate) fn build_execution_result(
61+
final_state: &MachineState,
62+
steps: usize,
63+
elapsed: Duration,
64+
) -> ExecutionResult {
65+
ExecutionResult {
66+
final_state: final_state.clone(),
67+
steps,
68+
halted: final_state.halted,
69+
elapsed,
70+
tokens_per_sec: tokens_per_sec(steps, elapsed),
71+
}
72+
}
73+
74+
pub(crate) fn record_execution_step(
75+
trace: &mut Vec<MachineState>,
76+
events: &mut Vec<ExecutionTraceEntry>,
77+
step: usize,
78+
layer_idx: Option<usize>,
79+
instruction: Instruction,
80+
state_before: MachineState,
81+
state_after: &MachineState,
82+
) {
83+
let state_after = state_after.clone();
84+
trace.push(state_after.clone());
85+
events.push(ExecutionTraceEntry {
86+
step,
87+
layer_idx,
88+
instruction,
89+
state_before,
90+
state_after,
91+
});
92+
}
93+
94+
fn tokens_per_sec(steps: usize, elapsed: Duration) -> f64 {
95+
let elapsed_secs = elapsed.as_secs_f64();
96+
if elapsed_secs > 0.0 {
97+
steps as f64 / elapsed_secs
98+
} else {
99+
0.0
49100
}
50101
}

0 commit comments

Comments
 (0)