Skip to content

Commit 7385a3a

Browse files
authored
Cleanup autodiff unused roots (tracel-ai#4039)
1 parent 0993b00 commit 7385a3a

File tree

8 files changed

+84
-68
lines changed

8 files changed

+84
-68
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,7 @@ indicatif = "0.18.0"
6464
js-sys = "0.3.77"
6565
libm = "0.2.15"
6666
log = { default-features = false, version = "0.4.28" }
67+
parking_lot = { version = "0.12.5", default-features = false }
6768
paste = "1"
6869
planus = { version = "=1.1" }
6970
polars = { version = "0.51.0", features = ["lazy"] }

crates/burn-autodiff/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ workspace = true
1717
[features]
1818
default = ["std"]
1919
export_tests = ["burn-tensor-testgen"]
20-
std = []
20+
std = ["dep:parking_lot"]
2121

2222
[dependencies]
2323
burn-common = { path = "../burn-common", version = "=0.20.0-pre.2", default-features = false }
@@ -26,11 +26,13 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "=0.20.0-pre.
2626

2727
derive-new = { workspace = true }
2828
spin = { workspace = true }
29+
parking_lot = { workspace = true, optional = true }
2930
log = { workspace = true }
3031
hashbrown = { workspace = true }
3132
num-traits = { workspace = true }
3233
portable-atomic = { workspace = true }
3334

35+
3436
[dev-dependencies]
3537
burn-tensor = { path = "../burn-tensor", version = "=0.20.0-pre.2", default-features = false, features = [
3638
"export_tests",

crates/burn-autodiff/src/runtime/client.rs

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -10,15 +10,6 @@ use burn_tensor::backend::Backend;
1010
pub trait AutodiffClient: Send + Clone {
1111
/// Register a new step.
1212
fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder);
13-
/// Register a new untracked step.
14-
fn register_untracked(
15-
&self,
16-
node_id: NodeRefCount,
17-
step: StepBoxed,
18-
actions: CheckpointerBuilder,
19-
) {
20-
self.register(node_id, step, actions);
21-
}
2213
/// Call backpropagation from the given tensor.
2314
fn backward<B: Backend>(&self, tensor: AutodiffTensor<B>) -> Gradients;
2415
}

crates/burn-autodiff/src/runtime/graph.rs

Lines changed: 47 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,16 @@ use crate::{
88
tensor::{AutodiffTensor, NodeRefCount},
99
};
1010
use alloc::sync::Arc;
11-
use burn_common::stub::Mutex;
11+
use alloc::vec::Vec;
1212
use burn_tensor::backend::Backend;
1313
use hashbrown::{HashMap, HashSet};
1414

15+
#[cfg(feature = "std")]
16+
use parking_lot::{Mutex, MutexGuard};
17+
18+
#[cfg(not(feature = "std"))]
19+
use spin::{Mutex, MutexGuard};
20+
1521
/// A client for managing multiple graphs using mutex-based synchronization.
1622
///
1723
/// The biggest benefit of using this client implementation is that each graph can modify its own
@@ -40,7 +46,6 @@ pub struct GraphLocator {
4046
/// This is to ensure that when merging graphs, we correctly move all previous graphs to
4147
/// the new merged one.
4248
keys: HashMap<NodeId, HashSet<NodeId>>,
43-
untracked_nodes: HashSet<NodeId>,
4449
}
4550

4651
/// Represents a single computation graph with a mutex-protected server.
@@ -65,7 +70,7 @@ impl core::fmt::Debug for Graph {
6570
}
6671
}
6772

68-
static STATE: spin::Mutex<Option<GraphLocator>> = spin::Mutex::new(None);
73+
static STATE: Mutex<Option<GraphLocator>> = Mutex::new(None);
6974

7075
impl GraphMutexClient {
7176
/// Retrieves or creates a graph for the given [NodeId] and parent dependencies.
@@ -95,52 +100,62 @@ impl AutodiffClient for GraphMutexClient {
95100
fn register(&self, node_id_ref: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {
96101
let node_id = *node_id_ref;
97102
let graph = GraphMutexClient::graph(node_id, step.parents());
98-
let mut state = graph.state.lock().unwrap();
103+
let mut state = graph.state.lock();
99104

100105
state.server.register(node_id_ref, step, actions);
101106
}
102107

103-
fn register_untracked(
104-
&self,
105-
node_id_ref: NodeRefCount,
106-
step: StepBoxed,
107-
actions: CheckpointerBuilder,
108-
) {
109-
let node_id = *node_id_ref;
110-
// Register normally (might be needed by tracked children, required for checkpointing)
111-
self.register(node_id_ref, step, actions);
112-
113-
// But mark this node for cleanup after backward
114-
let mut state = STATE.lock();
115-
if let Some(locator) = state.as_mut() {
116-
locator.mark_untracked(node_id);
117-
}
118-
}
119-
120108
fn backward<B: Backend>(&self, root: AutodiffTensor<B>) -> Gradients {
121109
let node_id = root.node.id;
122110
let graph = GraphMutexClient::graph(root.node.id, &[]);
123111

124112
let grads = Gradients::new::<B>(root.node, root.primitive);
125-
let mut state = graph.state.lock().unwrap();
126-
127-
let grads = state.server.backward::<GraphCleaner>(grads, node_id);
113+
let grads = {
114+
let mut state = graph.state.lock();
115+
state.server.backward::<GraphCleaner>(grads, node_id)
116+
}; // lock released
128117

129-
let mut cleaner = GraphCleaner::init();
130-
cleaner.cleanup_orphaned_entries();
118+
GraphCleaner::cleanup_orphaned_entries();
131119

132120
grads
133121
}
134122
}
135123

136124
struct GraphCleaner<'a> {
137-
guard: spin::MutexGuard<'a, Option<GraphLocator>>,
125+
guard: MutexGuard<'a, Option<GraphLocator>>,
138126
}
139127

140128
impl<'a> GraphCleaner<'a> {
141-
fn cleanup_orphaned_entries(&mut self) {
142-
if let Some(state) = self.guard.as_mut() {
143-
state.cleanup_untracked();
129+
fn cleanup_orphaned_entries() {
130+
let graphs = {
131+
// Get the available graphs and release the lock
132+
match STATE.lock().as_ref() {
133+
Some(state) => state.graphs.clone(),
134+
None => return,
135+
}
136+
};
137+
138+
let mut should_remove = Vec::new();
139+
for graph in graphs.values() {
140+
{
141+
let mut guard = graph.state.lock();
142+
// Double safety: in case it was marked as no longer useful, but other
143+
// nodes are still relevant, we only check which nodes can safely be removed.
144+
if !guard.server.maybe_useful() {
145+
guard
146+
.server
147+
.free_unused_roots(|node| should_remove.push(*node));
148+
}
149+
}
150+
}
151+
152+
if !should_remove.is_empty() {
153+
let mut state = STATE.lock();
154+
if let Some(state) = state.as_mut() {
155+
for node in should_remove {
156+
state.graphs.remove(&node);
157+
}
158+
}
144159
}
145160
}
146161
}
@@ -230,7 +245,7 @@ impl GraphLocator {
230245
let main = graphs.next().expect("At least one graph");
231246
self.register_key(main.origin, node);
232247

233-
let mut state = main.state.lock().unwrap();
248+
let mut state = main.state.lock();
234249

235250
for graph in graphs {
236251
self.merge_two(&mut state, &main, graph);
@@ -259,7 +274,7 @@ impl GraphLocator {
259274

260275
/// Merges two graphs by combining their states and updating graph mappings.
261276
fn merge_two(&mut self, main_state: &mut GraphState, main: &Arc<Graph>, merged: Arc<Graph>) {
262-
let mut locked = merged.state.lock().unwrap();
277+
let mut locked = merged.state.lock();
263278
let mut state_old = GraphState::default();
264279
core::mem::swap(&mut state_old, &mut locked);
265280
main_state.server.extend(state_old.server);
@@ -292,18 +307,6 @@ impl GraphLocator {
292307
graph
293308
}
294309

295-
fn mark_untracked(&mut self, node_id: NodeId) {
296-
self.untracked_nodes.insert(node_id);
297-
}
298-
299-
/// Clean up untracked nodes.
300-
fn cleanup_untracked(&mut self) {
301-
let mut nodes = core::mem::take(&mut self.untracked_nodes);
302-
for node in nodes.drain() {
303-
self.remove_entry(&node);
304-
}
305-
}
306-
307310
fn remove_entry(&mut self, node: &NodeId) {
308311
if let Some(graph) = self.graphs.remove(node) {
309312
let mut remove = false;

crates/burn-autodiff/src/runtime/memory_management.rs

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,17 @@ impl GraphMemoryManagement {
8989
}
9090
}
9191

92-
fn clear_unused_roots(&mut self, to_delete: &mut Vec<NodeId>) {
92+
pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
93+
let mut deletables = Vec::new();
94+
self.clear_unused_roots(&mut deletables);
95+
96+
for node_id in deletables {
97+
self.nodes.remove(&node_id);
98+
on_free_graph(&node_id);
99+
}
100+
}
101+
102+
fn clear_unused_roots(&self, to_delete: &mut Vec<NodeId>) {
93103
for (id, parents) in self.nodes.iter() {
94104
let is_useful = matches!(
95105
self.statuses.get(id.as_ref()),
@@ -250,6 +260,10 @@ impl GraphMemoryManagement {
250260
None => panic!("Node should be in the nodes map"),
251261
}
252262
}
263+
264+
pub(crate) fn maybe_useful(&self) -> bool {
265+
self.nodes.keys().any(|node| Arc::strong_count(node) > 1)
266+
}
253267
}
254268

255269
/// Wrapper over hash set for fast popping of any node

crates/burn-autodiff/src/runtime/server.rs

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,14 @@ impl AutodiffServer {
7171
gradients
7272
}
7373

74+
pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
75+
self.memory_management.free_unused_roots(|node_id| {
76+
self.steps.remove(node_id);
77+
self.actions_builder.remove(node_id);
78+
on_free_graph(node_id);
79+
});
80+
}
81+
7482
fn build_tape(
7583
&mut self,
7684
node: NodeId,
@@ -128,4 +136,8 @@ impl AutodiffServer {
128136

129137
grads
130138
}
139+
140+
pub(crate) fn maybe_useful(&self) -> bool {
141+
self.memory_management.maybe_useful()
142+
}
131143
}

crates/burn-autodiff/src/tensor.rs

Lines changed: 5 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -156,19 +156,11 @@ impl<B: Backend> AutodiffTensor<B> {
156156
step_that_created_the_tensor: S,
157157
actions: CheckpointerBuilder,
158158
) -> Self {
159-
if self.is_tracked() {
160-
self.node.client.register(
161-
self.rc.clone(),
162-
Box::new(step_that_created_the_tensor),
163-
actions,
164-
);
165-
} else {
166-
self.node.client.register_untracked(
167-
self.rc.clone(),
168-
Box::new(step_that_created_the_tensor),
169-
actions,
170-
);
171-
}
159+
self.node.client.register(
160+
self.rc.clone(),
161+
Box::new(step_that_created_the_tensor),
162+
actions,
163+
);
172164
self
173165
}
174166

0 commit comments

Comments
 (0)