@@ -8,10 +8,16 @@ use crate::{
88 tensor:: { AutodiffTensor , NodeRefCount } ,
99} ;
1010use alloc:: sync:: Arc ;
11- use burn_common :: stub :: Mutex ;
11+ use alloc :: vec :: Vec ;
1212use burn_tensor:: backend:: Backend ;
1313use hashbrown:: { HashMap , HashSet } ;
1414
15+ #[ cfg( feature = "std" ) ]
16+ use parking_lot:: { Mutex , MutexGuard } ;
17+
18+ #[ cfg( not( feature = "std" ) ) ]
19+ use spin:: { Mutex , MutexGuard } ;
20+
1521/// A client for managing multiple graphs using mutex-based synchronization.
1622///
1723/// The biggest benefit of using this client implementation is that each graph can modify its own
@@ -40,7 +46,6 @@ pub struct GraphLocator {
4046 /// This is to ensure that when merging graphs, we correctly move all previous graphs to
4147 /// the new merged one.
4248 keys : HashMap < NodeId , HashSet < NodeId > > ,
43- untracked_nodes : HashSet < NodeId > ,
4449}
4550
4651/// Represents a single computation graph with a mutex-protected server.
@@ -65,7 +70,7 @@ impl core::fmt::Debug for Graph {
6570 }
6671}
6772
68- static STATE : spin :: Mutex < Option < GraphLocator > > = spin :: Mutex :: new ( None ) ;
73+ static STATE : Mutex < Option < GraphLocator > > = Mutex :: new ( None ) ;
6974
7075impl GraphMutexClient {
7176 /// Retrieves or creates a graph for the given [NodeId] and parent dependencies.
@@ -95,52 +100,62 @@ impl AutodiffClient for GraphMutexClient {
95100 fn register ( & self , node_id_ref : NodeRefCount , step : StepBoxed , actions : CheckpointerBuilder ) {
96101 let node_id = * node_id_ref;
97102 let graph = GraphMutexClient :: graph ( node_id, step. parents ( ) ) ;
98- let mut state = graph. state . lock ( ) . unwrap ( ) ;
103+ let mut state = graph. state . lock ( ) ;
99104
100105 state. server . register ( node_id_ref, step, actions) ;
101106 }
102107
103- fn register_untracked (
104- & self ,
105- node_id_ref : NodeRefCount ,
106- step : StepBoxed ,
107- actions : CheckpointerBuilder ,
108- ) {
109- let node_id = * node_id_ref;
110- // Register normally (might be needed by tracked children, required for checkpointing)
111- self . register ( node_id_ref, step, actions) ;
112-
113- // But mark this node for cleanup after backward
114- let mut state = STATE . lock ( ) ;
115- if let Some ( locator) = state. as_mut ( ) {
116- locator. mark_untracked ( node_id) ;
117- }
118- }
119-
120108 fn backward < B : Backend > ( & self , root : AutodiffTensor < B > ) -> Gradients {
121109 let node_id = root. node . id ;
122110 let graph = GraphMutexClient :: graph ( root. node . id , & [ ] ) ;
123111
124112 let grads = Gradients :: new :: < B > ( root. node , root. primitive ) ;
125- let mut state = graph. state . lock ( ) . unwrap ( ) ;
126-
127- let grads = state. server . backward :: < GraphCleaner > ( grads, node_id) ;
113+ let grads = {
114+ let mut state = graph. state . lock ( ) ;
115+ state. server . backward :: < GraphCleaner > ( grads, node_id)
116+ } ; // lock released
128117
129- let mut cleaner = GraphCleaner :: init ( ) ;
130- cleaner. cleanup_orphaned_entries ( ) ;
118+ GraphCleaner :: cleanup_orphaned_entries ( ) ;
131119
132120 grads
133121 }
134122}
135123
136124struct GraphCleaner < ' a > {
137- guard : spin :: MutexGuard < ' a , Option < GraphLocator > > ,
125+ guard : MutexGuard < ' a , Option < GraphLocator > > ,
138126}
139127
140128impl < ' a > GraphCleaner < ' a > {
141- fn cleanup_orphaned_entries ( & mut self ) {
142- if let Some ( state) = self . guard . as_mut ( ) {
143- state. cleanup_untracked ( ) ;
129+ fn cleanup_orphaned_entries ( ) {
130+ let graphs = {
131+ // Get the available graphs and release the lock
132+ match STATE . lock ( ) . as_ref ( ) {
133+ Some ( state) => state. graphs . clone ( ) ,
134+ None => return ,
135+ }
136+ } ;
137+
138+ let mut should_remove = Vec :: new ( ) ;
139+ for graph in graphs. values ( ) {
140+ {
141+ let mut guard = graph. state . lock ( ) ;
142+ // Double safety: in case it was marked as no longer useful, but other
143+ // nodes are still relevant, we only check which nodes can safely be removed.
144+ if !guard. server . maybe_useful ( ) {
145+ guard
146+ . server
147+ . free_unused_roots ( |node| should_remove. push ( * node) ) ;
148+ }
149+ }
150+ }
151+
152+ if !should_remove. is_empty ( ) {
153+ let mut state = STATE . lock ( ) ;
154+ if let Some ( state) = state. as_mut ( ) {
155+ for node in should_remove {
156+ state. graphs . remove ( & node) ;
157+ }
158+ }
144159 }
145160 }
146161}
@@ -230,7 +245,7 @@ impl GraphLocator {
230245 let main = graphs. next ( ) . expect ( "At least one graph" ) ;
231246 self . register_key ( main. origin , node) ;
232247
233- let mut state = main. state . lock ( ) . unwrap ( ) ;
248+ let mut state = main. state . lock ( ) ;
234249
235250 for graph in graphs {
236251 self . merge_two ( & mut state, & main, graph) ;
@@ -259,7 +274,7 @@ impl GraphLocator {
259274
260275 /// Merges two graphs by combining their states and updating graph mappings.
261276 fn merge_two ( & mut self , main_state : & mut GraphState , main : & Arc < Graph > , merged : Arc < Graph > ) {
262- let mut locked = merged. state . lock ( ) . unwrap ( ) ;
277+ let mut locked = merged. state . lock ( ) ;
263278 let mut state_old = GraphState :: default ( ) ;
264279 core:: mem:: swap ( & mut state_old, & mut locked) ;
265280 main_state. server . extend ( state_old. server ) ;
@@ -292,18 +307,6 @@ impl GraphLocator {
292307 graph
293308 }
294309
295- fn mark_untracked ( & mut self , node_id : NodeId ) {
296- self . untracked_nodes . insert ( node_id) ;
297- }
298-
299- /// Clean up untracked nodes.
300- fn cleanup_untracked ( & mut self ) {
301- let mut nodes = core:: mem:: take ( & mut self . untracked_nodes ) ;
302- for node in nodes. drain ( ) {
303- self . remove_entry ( & node) ;
304- }
305- }
306-
307310 fn remove_entry ( & mut self , node : & NodeId ) {
308311 if let Some ( graph) = self . graphs . remove ( node) {
309312 let mut remove = false ;
0 commit comments