Skip to content

Commit 7a54250

Browse files
asimshankar authored and gunan committed
Merge some memory leak fixing changes from master to r1.11 (tensorflow#22404)
1 parent e4c4b20 commit 7a54250

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

tensorflow/python/keras/backend.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,16 @@
7373
# This dictionary holds a mapping {graph: learning_phase}.
7474
# A learning phase is a bool tensor used to run Keras models in
7575
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
76-
_GRAPH_LEARNING_PHASES = {}
76+
_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
77+
78+
79+
# _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES.
80+
# We keep a separate reference to it to make sure it does not get removed from
81+
# _GRAPH_LEARNING_PHASES. We use a dummy class instead of something like a
82+
# string because strings are not weakly-referencable.
83+
class _DummyEagerGraph(object):
84+
pass
85+
_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
7786

7887
# This boolean flag can be set to True to leave variable initialization
7988
# up to the user.
@@ -96,11 +105,11 @@
96105

97106
# This dictionary holds a mapping between a graph and variables to initialize
98107
# in the graph.
99-
_GRAPH_VARIABLES = {}
108+
_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
100109

101110
# This dictionary holds a mapping between a graph and TF optimizers created in
102111
# the graph.
103-
_GRAPH_TF_OPTIMIZERS = {}
112+
_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
104113

105114

106115
@tf_export('keras.backend.backend')
@@ -359,10 +368,10 @@ def learning_phase():
359368
Learning phase (scalar integer tensor or Python integer).
360369
"""
361370
if context.executing_eagerly():
362-
if 'eager' not in _GRAPH_LEARNING_PHASES:
371+
if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
363372
# Fallback to inference mode as default.
364373
return 0
365-
return _GRAPH_LEARNING_PHASES['eager']
374+
return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
366375

367376
graph = ops.get_default_graph()
368377
if graph not in _GRAPH_LEARNING_PHASES:
@@ -386,7 +395,7 @@ def set_learning_phase(value):
386395
if value not in {0, 1}:
387396
raise ValueError('Expected learning phase to be 0 or 1.')
388397
if context.executing_eagerly():
389-
_GRAPH_LEARNING_PHASES['eager'] = value
398+
_GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
390399
else:
391400
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
392401

@@ -415,7 +424,7 @@ def learning_phase_scope(value):
415424
finally:
416425
# Restore learning phase to initial value.
417426
if context.executing_eagerly():
418-
_GRAPH_LEARNING_PHASES['eager'] = previous_value
427+
_GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
419428
else:
420429
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
421430

@@ -683,14 +692,14 @@ def track_variable(v):
683692
return
684693
graph = v.graph if hasattr(v, 'graph') else ops.get_default_graph()
685694
if graph not in _GRAPH_VARIABLES:
686-
_GRAPH_VARIABLES[graph] = set()
695+
_GRAPH_VARIABLES[graph] = weakref.WeakSet()
687696
_GRAPH_VARIABLES[graph].add(v)
688697

689698

690699
def _get_variables(graph=None):
691700
"""Returns variables corresponding to the given graph for initialization."""
692701
assert not context.executing_eagerly()
693-
variables = _GRAPH_VARIABLES.get(graph, set())
702+
variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
694703
for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
695704
variables.update(opt.optimizer.variables())
696705
return variables

0 commit comments

Comments (0)