
Commit 7131166

gehringericl authored and committed
[rllib] Tracing for eager tensorflow policies with tf.function (#5705)
* Added tracing of eager policies with `tf.function`
* lint
* add config option
* add docs
* wip
* tracing now works with a3c
* typo
* none
* file doc
* returns
* syntax error
* syntax error
1 parent d1e4b36 commit 7131166

File tree

11 files changed: 204 additions, 37 deletions


doc/source/rllib-concepts.rst

Lines changed: 2 additions & 2 deletions

@@ -418,9 +418,9 @@ Finally, note that you do not have to use ``build_tf_policy`` to define a Tensor
 Building Policies in TensorFlow Eager
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` config option or using ``rllib train --eager``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
+Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` / ``"eager_tracing": True`` config options or using ``rllib train --eager [--trace]``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
 
-Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it is slower than graph mode.
+Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it can be slower than graph mode unless tracing is enabled.
 
 You can also selectively leverage eager operations within graph mode execution with `tf.py_function <https://www.tensorflow.org/api_docs/python/tf/py_function>`__. Here's an example of using eager ops embedded `within a loss function <https://github.com/ray-project/ray/blob/master/rllib/examples/eager_execution.py>`__.
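For context, a minimal sketch of what enabling these options looks like from Python (assuming a standard ``ray``/``tune`` install and the ``CartPole-v0`` Gym environment; only the ``"eager"`` and ``"eager_tracing"`` keys come from this commit, the rest is ordinary Tune usage):

    import ray
    from ray import tune

    ray.init()

    # Train A3C with eager execution plus tf.function tracing enabled.
    tune.run(
        "A3C",
        stop={"training_iteration": 5},
        config={
            "env": "CartPole-v0",
            "eager": True,          # run TF policy ops eagerly
            "eager_tracing": True,  # wrap the eager pass in tf.function
        })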

doc/source/rllib-training.rst

Lines changed: 3 additions & 3 deletions

@@ -14,7 +14,7 @@ You can train a simple DQN trainer with the following command:
 
 .. code-block:: bash
 
-    rllib train --run DQN --env CartPole-v0  # add --eager for eager execution
+    rllib train --run DQN --env CartPole-v0  # --eager [--trace] for eager execution
 
 By default, the results will be logged to a subdirectory of ``~/ray_results``.
 This subdirectory will contain a file ``params.json`` which contains the

@@ -544,9 +544,9 @@ The ``"monitor": true`` config can be used to save Gym episode videos to the res
 Eager Mode
 ~~~~~~~~~~
 
-Policies built with ``build_tf_policy`` can be also run in eager mode by setting the ``"eager": True`` config option or using ``rllib train --eager``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
+Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` / ``"eager_tracing": True`` config options or using ``rllib train --eager [--trace]``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
 
-Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it is slower than graph mode.
+Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it can be slower than graph mode unless tracing is enabled.
 
 Episode Traces
 ~~~~~~~~~~~~~~
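To make the debugging claim concrete, here is a hypothetical loss function (not part of this commit) of the kind passed to ``build_tf_policy``, with an interface mirroring the ``custom_tf_policy.py`` example touched below. Under ``"eager": True`` the plain ``print()`` shows real tensor values on every update; in graph mode it would only print a symbolic tensor once while the graph is built:

    import tensorflow as tf

    # Hypothetical loss_fn for build_tf_policy() (illustrative only).
    def debug_pg_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        logp = action_dist.logp(train_batch["actions"])
        # Eager mode: prints concrete values on each call.
        # Graph mode: prints a symbolic Tensor once at graph-construction time.
        print("logp (first 5):", logp[:5])
        return -tf.reduce_mean(logp * train_batch["returns"])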

doc/source/rllib.rst

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ Then, you can try out training in the following equivalent ways:
 
 .. code-block:: bash
 
-    rllib train --run=PPO --env=CartPole-v0  # add --eager for eager execution
+    rllib train --run=PPO --env=CartPole-v0  # --eager [--trace] for eager execution
 
 .. code-block:: python

rllib/agents/trainer.py

Lines changed: 7 additions & 2 deletions

@@ -70,8 +70,12 @@
     "ignore_worker_failures": False,
     # Log system resource metrics to results.
     "log_sys_usage": True,
-    # Enable TF eager execution (TF policies only)
+    # Enable TF eager execution (TF policies only).
     "eager": False,
+    # Enable tracing in eager mode. This greatly improves performance, but
+    # makes it slightly harder to debug since Python code won't be evaluated
+    # after the initial eager pass.
+    "eager_tracing": False,
     # Disable eager execution on workers (but allow it on the driver). This
     # only has an effect is eager is enabled.
     "no_eager_on_workers": False,

@@ -333,7 +337,8 @@ def __init__(self, config=None, env=None, logger_creator=None):
 
         if tf and config.get("eager"):
             tf.enable_eager_execution()
-            logger.info("Executing eagerly")
+            logger.info("Executing eagerly, with eager_tracing={}".format(
+                "True" if config.get("eager_tracing") else "False"))
 
         if tf and not tf.executing_eagerly():
             logger.info("Tip: set 'eager': true or the --eager flag to enable "

rllib/evaluation/rollout_worker.py

Lines changed: 2 additions & 0 deletions

@@ -752,6 +752,8 @@ def _build_policy_map(self, policy_dict, policy_config):
             if tf and tf.executing_eagerly():
                 if hasattr(cls, "as_eager"):
                     cls = cls.as_eager()
+                    if policy_config["eager_tracing"]:
+                        cls = cls.with_tracing()
                 elif not issubclass(cls, TFPolicy):
                     pass  # could be some other type of policy
                 else:
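``with_tracing()`` itself is defined on the eager policy class (in ``rllib/policy/eager_tf_policy.py``, not shown in this diff). As a rough, illustrative sketch of the pattern, not RLlib's actual implementation, a class method can return a subclass whose forward pass is wrapped in ``tf.function``:

    import tensorflow as tf

    class EagerToyPolicy:
        """Toy stand-in for an eager policy class (illustrative only)."""

        @classmethod
        def with_tracing(cls):
            class Traced(cls):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)
                    # Wrap the eager forward pass once; later calls reuse the
                    # traced graph instead of re-running the Python code.
                    self.forward = tf.function(self.forward)

            Traced.__name__ = cls.__name__ + "_traced"
            return Traced

        def forward(self, obs):
            return tf.reduce_sum(obs, axis=-1)  # stand-in for a model forward pass

    policy = EagerToyPolicy.with_tracing()()
    print(policy.forward(tf.ones([4, 3])))  # executes via the tf.function wrapper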

rllib/examples/custom_tf_policy.py

Lines changed: 2 additions & 2 deletions

@@ -21,14 +21,14 @@ def policy_gradient_loss(policy, model, dist_class, train_batch):
     logits, _ = model.from_batch(train_batch)
     action_dist = dist_class(logits, model)
     return -tf.reduce_mean(
-        action_dist.logp(train_batch["actions"]) * train_batch["advantages"])
+        action_dist.logp(train_batch["actions"]) * train_batch["returns"])
 
 
 def calculate_advantages(policy,
                          sample_batch,
                          other_agent_batches=None,
                          episode=None):
-    sample_batch["advantages"] = discount(sample_batch["rewards"], 0.99)
+    sample_batch["returns"] = discount(sample_batch["rewards"], 0.99)
     return sample_batch
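The ``discount`` helper used in ``calculate_advantages`` is imported elsewhere in the example file; as a reference, a minimal equivalent that computes the discounted return-to-go for each timestep could look like this:

    import numpy as np

    def discount(rewards, gamma):
        """out[t] = sum over k >= t of gamma**(k - t) * rewards[k]."""
        out = np.zeros(len(rewards))
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running
            out[t] = running
        return out

    print(discount([1.0, 1.0, 1.0], 0.99))  # -> approximately [2.9701, 1.99, 1.0]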

rllib/policy/dynamic_tf_policy.py

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+"""Graph mode TF policy built using build_tf_policy()."""
 
 from collections import OrderedDict
 import logging
