@@ -73,7 +73,16 @@
 # This dictionary holds a mapping {graph: learning_phase}.
 # A learning phase is a bool tensor used to run Keras models in
 # either train mode (learning_phase == 1) or test mode (learning_phase == 0).
-_GRAPH_LEARNING_PHASES = {}
+_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
+
+
+# _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES.
+# We keep a separate reference to it to make sure it does not get removed from
+# _GRAPH_LEARNING_PHASES. We use a dummy class instead of something like a
+# string because strings are not weakly-referencable.
+class _DummyEagerGraph(object):
+  pass
+_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
 
 # This boolean flag can be set to True to leave variable initialization
 # up to the user.
@@ -96,11 +105,11 @@
 
 # This dictionary holds a mapping between a graph and variables to initialize
 # in the graph.
-_GRAPH_VARIABLES = {}
+_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
 
 # This dictionary holds a mapping between a graph and TF optimizers created in
 # the graph.
-_GRAPH_TF_OPTIMIZERS = {}
+_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
 
 
 @tf_export('keras.backend.backend')
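
Why these changes: a module-level dict keyed by `tf.Graph` objects keeps every graph (and everything hanging off it) alive for the life of the process; `weakref.WeakKeyDictionary` drops an entry as soon as its graph is garbage-collected. The sentinel class exists because the old `'eager'` string key cannot survive this switch: strings are not weakly referenceable. A minimal sketch of both behaviors, assuming CPython (where refcounting reclaims objects immediately); `_DummyKey` is a hypothetical stand-in for `_DummyEagerGraph`:

```python
import weakref

class _DummyKey(object):  # hypothetical stand-in for _DummyEagerGraph
  pass

phases = weakref.WeakKeyDictionary()
key = _DummyKey()
phases[key] = 1
print(len(phases))  # 1: the entry lives while a strong reference to key exists

del key  # drop the last strong reference
print(len(phases))  # 0: the entry vanished with its key

try:
  weakref.ref('eager')  # the old string key cannot be weakly referenced
except TypeError as err:
  print(err)  # cannot create weak reference to 'str' object
```

This is also why the patch keeps the module-level `_DUMMY_EAGER_GRAPH` strong reference: without it, the eager learning-phase entry would be collected out from under the backend.
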
@@ -359,10 +368,10 @@ def learning_phase():
       Learning phase (scalar integer tensor or Python integer).
   """
   if context.executing_eagerly():
-    if 'eager' not in _GRAPH_LEARNING_PHASES:
+    if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
       # Fallback to inference mode as default.
       return 0
-    return _GRAPH_LEARNING_PHASES['eager']
+    return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
 
   graph = ops.get_default_graph()
   if graph not in _GRAPH_LEARNING_PHASES:
@@ -386,7 +395,7 @@ def set_learning_phase(value):
   if value not in {0, 1}:
     raise ValueError('Expected learning phase to be 0 or 1.')
   if context.executing_eagerly():
-    _GRAPH_LEARNING_PHASES['eager'] = value
+    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
   else:
     _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
 
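
Observable behavior is unchanged; only the internal dictionary key differs. A rough usage sketch, assuming a TF 1.x build that contains this change (`tf.enable_eager_execution` is the 1.x-era API):

```python
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x-era eager switch

print(tf.keras.backend.learning_phase())  # 0: inference mode is the default
tf.keras.backend.set_learning_phase(1)    # now stored under _DUMMY_EAGER_GRAPH
print(tf.keras.backend.learning_phase())  # 1
```
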
@@ -415,7 +424,7 @@ def learning_phase_scope(value):
   finally:
     # Restore learning phase to initial value.
     if context.executing_eagerly():
-      _GRAPH_LEARNING_PHASES['eager'] = previous_value
+      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
     else:
       _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
 
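
The restore path of `learning_phase_scope` gets the same substitution, so the saved value is written back under the sentinel rather than the string. A sketch, assuming the scope helper is exposed under `tf.keras.backend` on this build:

```python
import tensorflow as tf

tf.enable_eager_execution()
K = tf.keras.backend

K.set_learning_phase(0)
with K.learning_phase_scope(1):
  print(K.learning_phase())  # 1: training mode inside the scope
print(K.learning_phase())    # 0: previous value restored on exit
```
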
@@ -683,14 +692,14 @@ def track_variable(v):
     return
   graph = v.graph if hasattr(v, 'graph') else ops.get_default_graph()
   if graph not in _GRAPH_VARIABLES:
-    _GRAPH_VARIABLES[graph] = set()
+    _GRAPH_VARIABLES[graph] = weakref.WeakSet()
   _GRAPH_VARIABLES[graph].add(v)
 
 
 def _get_variables(graph=None):
   """Returns variables corresponding to the given graph for initialization."""
   assert not context.executing_eagerly()
-  variables = _GRAPH_VARIABLES.get(graph, set())
+  variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
   for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
     variables.update(opt.optimizer.variables())
   return variables
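
`track_variable` and `_get_variables` get the matching treatment: a `weakref.WeakSet` stops the per-graph registry from pinning variables that nothing else references, and `setdefault` (rather than `get` with a throwaway default) ensures the set stored in `_GRAPH_VARIABLES` is the same object that later absorbs the optimizer variables. A minimal sketch of the WeakSet behavior; `FakeVariable` is a hypothetical stand-in for a real variable object:

```python
import weakref

class FakeVariable(object):  # hypothetical stand-in for a tracked variable
  pass

tracked = weakref.WeakSet()
v = FakeVariable()
tracked.add(v)
print(len(tracked))  # 1: tracked while a strong reference exists elsewhere

del v  # last strong reference gone
print(len(tracked))  # 0: the registry no longer keeps dead variables alive
```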