From 4c3104e3fe561a1836d4699a11799ee48b1d86c1 Mon Sep 17 00:00:00 2001 From: Matthew Date: Sun, 19 Oct 2025 20:08:16 -0500 Subject: [PATCH 1/6] Fixed value bootstrapping bug for PPO and added a test to detect future regressions. Signed-off-by: Matthew --- rllib/BUILD.bazel | 10 +++ .../ppo/tests/test_ppo_value_bootstrapping.py | 88 +++++++++++++++++++ .../add_one_ts_to_episodes_and_truncate.py | 6 +- 3 files changed, 101 insertions(+), 3 deletions(-) create mode 100644 rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py diff --git a/rllib/BUILD.bazel b/rllib/BUILD.bazel index 4dbfccb6c865..061fad74557a 100644 --- a/rllib/BUILD.bazel +++ b/rllib/BUILD.bazel @@ -2299,6 +2299,16 @@ py_test( ], ) +py_test( + name = "test_ppo_value_bootstrapping", + size = "medium", + srcs = ["algorithms/ppo/tests/test_ppo_value_bootstrapping.py"], + tags = [ + "algorithms_dir", + "team:rllib", + ], +) + # SAC py_test( name = "test_sac", diff --git a/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py b/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py new file mode 100644 index 000000000000..15479fcfa8e7 --- /dev/null +++ b/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py @@ -0,0 +1,88 @@ +import unittest + +import ray +import ray.rllib.algorithms.ppo as ppo +from ray.rllib.connectors.env_to_module import FlattenObservations +from ray.rllib.core.columns import Columns +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, EPISODE_RETURN_MEAN +import torch + +class TestPPO(unittest.TestCase): + @classmethod + def setUpClass(cls): + ray.init() + + @classmethod + def tearDownClass(cls): + ray.shutdown() + + def test_ppo_value_bootstrapping(self): + """Test whether PPO's value bootstrapping works properly.""" + + # Build a PPOConfig object with the `SingleAgentEnvRunner` class. + config = ( + ppo.PPOConfig() + .debugging(seed=1) + .environment( # A very simple environment with a terminal reward + "FrozenLake-v1", + env_config={ + "desc": [ + "HG", + "FF", + "SH", + "FH", + ], + "is_slippery": False, + "max_episode_steps": 3, + }, + ) + .env_runners( + num_env_runners=0, + # Flatten discrete observations (into one-hot vectors). 
+ env_to_module_connector=lambda env, spaces, device: FlattenObservations(), + ) + .training( + num_epochs=10, + lr=2e-4, + lambda_=0., # Zero means pure value bootstrapping + gamma=0.9, + train_batch_size=128, + ) + ) + + num_iterations = 20 + + algo = config.build() + + for i in range(num_iterations): + r_mean = algo.train()[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] + print(r_mean) + + # Test value predictions + critic = algo.learner_group._learner._module[DEFAULT_POLICY_ID] + state_values = {} + + for state in [3,2,4,6]: + obs = torch.zeros((8,)).float() + obs[state]+=1 + batch = {Columns.OBS: obs.unsqueeze(0)} + with torch.no_grad(): + value = critic.compute_values(batch).item() + print(f'State {state}: {value:.02f}') + state_values[state] = value + + algo.stop() + # Value bootstrapping should learn this simple environment reliably + self.assertGreater(r_mean, 0.9) + # The value function + self.assertGreater(state_values[3], 0.9) # Immediately terminates with reward 1 + self.assertGreater(state_values[2], 0.8) # One step from terminating with reward 1 + self.assertGreater(state_values[4], 0.7) # Two steps from terminating with reward 1 + self.assertLess(state_values[6], 0.7) # Cannot reach the goal from this state + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py b/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py index fcd3703eeb85..3ba0cc7f9adc 100644 --- a/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py +++ b/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py @@ -146,10 +146,10 @@ def __call__( len_ + 1, sa_episode, ) - + # Set the original terminating step to false, or value bootstrapping will + # ignore terminal rewards. See ppo/tests/test_value_bootstrapping.py. terminateds = ( - [False for _ in range(len_ - 1)] - + [bool(sa_episode.is_terminated)] + [False for _ in range(len_)] + [True] # extra timestep ) self.add_n_batch_items( From f34c43b1f423406e837f64c6af2c10253e28fb20 Mon Sep 17 00:00:00 2001 From: Matthew Date: Sun, 19 Oct 2025 20:39:30 -0500 Subject: [PATCH 2/6] Linted code Signed-off-by: Matthew --- .../ppo/tests/test_ppo_value_bootstrapping.py | 32 +++++++++++-------- .../add_one_ts_to_episodes_and_truncate.py | 5 +-- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py b/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py index 15479fcfa8e7..7dfb828d0f92 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py +++ b/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py @@ -8,6 +8,7 @@ from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, EPISODE_RETURN_MEAN import torch + class TestPPO(unittest.TestCase): @classmethod def setUpClass(cls): @@ -23,8 +24,8 @@ def test_ppo_value_bootstrapping(self): # Build a PPOConfig object with the `SingleAgentEnvRunner` class. 
config = ( ppo.PPOConfig() - .debugging(seed=1) - .environment( # A very simple environment with a terminal reward + .debugging(seed=0) + .environment( # A very simple environment with a terminal reward "FrozenLake-v1", env_config={ "desc": [ @@ -45,7 +46,7 @@ def test_ppo_value_bootstrapping(self): .training( num_epochs=10, lr=2e-4, - lambda_=0., # Zero means pure value bootstrapping + lambda_=0.0, # Zero means pure value bootstrapping gamma=0.9, train_batch_size=128, ) @@ -58,28 +59,33 @@ def test_ppo_value_bootstrapping(self): for i in range(num_iterations): r_mean = algo.train()[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] print(r_mean) - + # Test value predictions critic = algo.learner_group._learner._module[DEFAULT_POLICY_ID] state_values = {} - - for state in [3,2,4,6]: + + for state in [3, 2, 4, 6]: obs = torch.zeros((8,)).float() - obs[state]+=1 + obs[state] += 1 batch = {Columns.OBS: obs.unsqueeze(0)} with torch.no_grad(): - value = critic.compute_values(batch).item() - print(f'State {state}: {value:.02f}') + value = critic.compute_values(batch).item() + print(f"State {state}: {value:.02f}") state_values[state] = value algo.stop() # Value bootstrapping should learn this simple environment reliably self.assertGreater(r_mean, 0.9) # The value function - self.assertGreater(state_values[3], 0.9) # Immediately terminates with reward 1 - self.assertGreater(state_values[2], 0.8) # One step from terminating with reward 1 - self.assertGreater(state_values[4], 0.7) # Two steps from terminating with reward 1 - self.assertLess(state_values[6], 0.7) # Cannot reach the goal from this state + self.assertGreater(state_values[3], 0.9) # Immediately terminates with reward 1 + self.assertGreater( + state_values[2], 0.8 + ) # One step from terminating with reward 1 + self.assertGreater( + state_values[4], 0.7 + ) # Two steps from terminating with reward 1 + self.assertLess(state_values[6], 0.7) # Cannot reach the goal from this state + if __name__ == "__main__": import pytest diff --git a/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py b/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py index 3ba0cc7f9adc..5249f834f414 100644 --- a/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py +++ b/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py @@ -148,10 +148,7 @@ def __call__( ) # Set the original terminating step to false, or value bootstrapping will # ignore terminal rewards. See ppo/tests/test_value_bootstrapping.py. 
- terminateds = ( - [False for _ in range(len_)] - + [True] # extra timestep - ) + terminateds = [False for _ in range(len_)] + [True] # extra timestep self.add_n_batch_items( batch, Columns.TERMINATEDS, From 58ffb667ae5d58406b8aadc88ae5becc27239f4e Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 7 Jan 2026 19:36:36 -0600 Subject: [PATCH 3/6] Fixed edge case with truncation and added a more direct test for value target calculation Signed-off-by: Matthew --- .../ppo/tests/test_ppo_value_bootstrapping.py | 83 +++++++++++++++++++ .../add_one_ts_to_episodes_and_truncate.py | 7 +- 2 files changed, 87 insertions(+), 3 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py b/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py index 7dfb828d0f92..2a4b35b8104a 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py +++ b/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py @@ -6,8 +6,66 @@ from ray.rllib.core.columns import Columns from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, EPISODE_RETURN_MEAN +from ray.rllib.utils.postprocessing.value_predictions import compute_value_targets +from ray.rllib.utils.postprocessing.zero_padding import unpad_data_if_necessary + +from ray.rllib.env.single_agent_episode import SingleAgentEpisode +from ray.rllib.connectors.learner import AddOneTsToEpisodesAndTruncate + import torch +import numpy as np + +from ray.rllib.connectors.learner.learner_connector_pipeline import ( + LearnerConnectorPipeline, +) +from ray.rllib.connectors.learner import ( + AddColumnsFromEpisodesToTrainBatch, + BatchIndividualItems, + LearnerConnectorPipeline, +) +def simulate_vt_calculation(vfps, rewards, terminateds, truncateds, gamma, lambda_): + # Formatting + episodes = [] + for vfp, r, term, trunc in zip(vfps, rewards, terminateds, truncateds): + episodes.append(SingleAgentEpisode( + observations=[0]*len(vfp), # Include observation after last action + actions=[0]*len(r), + rewards=r, + terminated=term, + truncated=trunc, + len_lookback_buffer=0, + )) + episode_lens = [len(e) for e in episodes] + # Call AddOneTsToEpisodesAndTruncate + pipe = LearnerConnectorPipeline(connectors=[AddOneTsToEpisodesAndTruncate(), AddColumnsFromEpisodesToTrainBatch(), BatchIndividualItems()]) + batch = pipe( + episodes=episodes, + batch={}, + rl_module=None, + explore=False, + shared_data={}, + ) + # Add the last episode's terminated/truncated flags to `terminateds` and `truncateds` + vfps = [v for vfpl in vfps for v in vfpl] + # Compute the value targets + return compute_value_targets( + values=vfps, + rewards=unpad_data_if_necessary( + episode_lens, + np.array(batch[Columns.REWARDS]), + ), + terminateds=unpad_data_if_necessary( + episode_lens, + np.array(batch[Columns.TERMINATEDS]), + ), + truncateds=unpad_data_if_necessary( + episode_lens, + np.array(batch[Columns.TRUNCATEDS]), + ), + gamma=gamma, + lambda_=lambda_, + ) class TestPPO(unittest.TestCase): @classmethod @@ -17,6 +75,31 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): ray.shutdown() + + def test_value_computation(self): + correct = [0.9405, 1., None, 0.9405, 1., None] + two_term = simulate_vt_calculation( + [[0.0, 0.95, 0.95], [0.0, 0.95, 0.95]], # Value head outputs + [[0.0, 1.0], [0.0, 1.0]], # Environment rewards + [True, True], # Terminated flags + [False, False], # Truncated flags + gamma=0.99, lambda_=0.0, + ) + for pred, gt in zip(two_term, correct): + if (gt is not None): + self.assertEqual(pred, 
gt) + # Test case where an episode is truncated (state value should be included) + correct = [0.9405, 1., None, 0.9405, 1.9405, None] + term_trunc = simulate_vt_calculation( + [[0.0, 0.95, 0.95], [0.0, 0.95, 0.95]], # Value head outputs + [[0.0, 1.0], [0.0, 1.0]], # Environment rewards + [True, False], # Terminated flags + [False, True], # Truncated flags + gamma=0.99, lambda_=0.0, + ) + for pred, gt in zip(term_trunc, correct): + if (gt is not None): + self.assertEqual(pred, gt) def test_ppo_value_bootstrapping(self): """Test whether PPO's value bootstrapping works properly.""" diff --git a/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py b/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py index 5249f834f414..10663dc932b8 100644 --- a/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py +++ b/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py @@ -146,9 +146,10 @@ def __call__( len_ + 1, sa_episode, ) - # Set the original terminating step to false, or value bootstrapping will - # ignore terminal rewards. See ppo/tests/test_value_bootstrapping.py. - terminateds = [False for _ in range(len_)] + [True] # extra timestep + terminateds = ( + [False for _ in range(len_)] # Avoid ignoring last-step rewards when lambda=0 + + [bool(sa_episode.is_terminated)] # Use computed value for truncated eps. + ) # extra timestep self.add_n_batch_items( batch, Columns.TERMINATEDS, From ea7666054b02ff9cae115e914b0c8fd22126c14c Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 7 Jan 2026 19:47:55 -0600 Subject: [PATCH 4/6] linted Signed-off-by: Matthew --- .../ppo/tests/test_ppo_value_bootstrapping.py | 100 ++++++++++-------- .../add_one_ts_to_episodes_and_truncate.py | 9 +- 2 files changed, 59 insertions(+), 50 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py b/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py index 2a4b35b8104a..e4394c1b6919 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py +++ b/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py @@ -1,44 +1,48 @@ import unittest +import numpy as np +import torch + import ray import ray.rllib.algorithms.ppo as ppo from ray.rllib.connectors.env_to_module import FlattenObservations +from ray.rllib.connectors.learner import ( + AddColumnsFromEpisodesToTrainBatch, + AddOneTsToEpisodesAndTruncate, + BatchIndividualItems, + LearnerConnectorPipeline, +) from ray.rllib.core.columns import Columns +from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, EPISODE_RETURN_MEAN from ray.rllib.utils.postprocessing.value_predictions import compute_value_targets from ray.rllib.utils.postprocessing.zero_padding import unpad_data_if_necessary -from ray.rllib.env.single_agent_episode import SingleAgentEpisode -from ray.rllib.connectors.learner import AddOneTsToEpisodesAndTruncate - -import torch -import numpy as np - -from ray.rllib.connectors.learner.learner_connector_pipeline import ( - LearnerConnectorPipeline, -) -from ray.rllib.connectors.learner import ( - AddColumnsFromEpisodesToTrainBatch, - BatchIndividualItems, - LearnerConnectorPipeline, -) def simulate_vt_calculation(vfps, rewards, terminateds, truncateds, gamma, lambda_): # Formatting episodes = [] for vfp, r, term, trunc in zip(vfps, rewards, terminateds, truncateds): - episodes.append(SingleAgentEpisode( - observations=[0]*len(vfp), # Include observation after last 
action - actions=[0]*len(r), - rewards=r, - terminated=term, - truncated=trunc, - len_lookback_buffer=0, - )) + episodes.append( + SingleAgentEpisode( + observations=[0] * len(vfp), # Include observation after last action + actions=[0] * len(r), + rewards=r, + terminated=term, + truncated=trunc, + len_lookback_buffer=0, + ) + ) episode_lens = [len(e) for e in episodes] # Call AddOneTsToEpisodesAndTruncate - pipe = LearnerConnectorPipeline(connectors=[AddOneTsToEpisodesAndTruncate(), AddColumnsFromEpisodesToTrainBatch(), BatchIndividualItems()]) + pipe = LearnerConnectorPipeline( + connectors=[ + AddOneTsToEpisodesAndTruncate(), + AddColumnsFromEpisodesToTrainBatch(), + BatchIndividualItems(), + ] + ) batch = pipe( episodes=episodes, batch={}, @@ -52,21 +56,22 @@ def simulate_vt_calculation(vfps, rewards, terminateds, truncateds, gamma, lambd return compute_value_targets( values=vfps, rewards=unpad_data_if_necessary( - episode_lens, - np.array(batch[Columns.REWARDS]), + episode_lens, + np.array(batch[Columns.REWARDS]), ), terminateds=unpad_data_if_necessary( - episode_lens, - np.array(batch[Columns.TERMINATEDS]), + episode_lens, + np.array(batch[Columns.TERMINATEDS]), ), truncateds=unpad_data_if_necessary( - episode_lens, - np.array(batch[Columns.TRUNCATEDS]), + episode_lens, + np.array(batch[Columns.TRUNCATEDS]), ), gamma=gamma, lambda_=lambda_, ) + class TestPPO(unittest.TestCase): @classmethod def setUpClass(cls): @@ -75,30 +80,32 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): ray.shutdown() - + def test_value_computation(self): - correct = [0.9405, 1., None, 0.9405, 1., None] + correct = [0.9405, 1.0, None, 0.9405, 1.0, None] two_term = simulate_vt_calculation( - [[0.0, 0.95, 0.95], [0.0, 0.95, 0.95]], # Value head outputs - [[0.0, 1.0], [0.0, 1.0]], # Environment rewards - [True, True], # Terminated flags - [False, False], # Truncated flags - gamma=0.99, lambda_=0.0, + [[0.0, 0.95, 0.95], [0.0, 0.95, 0.95]], # Value head outputs + [[0.0, 1.0], [0.0, 1.0]], # Environment rewards + [True, True], # Terminated flags + [False, False], # Truncated flags + gamma=0.99, + lambda_=0.0, ) for pred, gt in zip(two_term, correct): - if (gt is not None): + if gt is not None: self.assertEqual(pred, gt) # Test case where an episode is truncated (state value should be included) - correct = [0.9405, 1., None, 0.9405, 1.9405, None] + correct = [0.9405, 1.0, None, 0.9405, 1.9405, None] term_trunc = simulate_vt_calculation( - [[0.0, 0.95, 0.95], [0.0, 0.95, 0.95]], # Value head outputs - [[0.0, 1.0], [0.0, 1.0]], # Environment rewards - [True, False], # Terminated flags - [False, True], # Truncated flags - gamma=0.99, lambda_=0.0, + [[0.0, 0.95, 0.95], [0.0, 0.95, 0.95]], # Value head outputs + [[0.0, 1.0], [0.0, 1.0]], # Environment rewards + [True, False], # Terminated flags + [False, True], # Truncated flags + gamma=0.99, + lambda_=0.0, ) for pred, gt in zip(term_trunc, correct): - if (gt is not None): + if gt is not None: self.assertEqual(pred, gt) def test_ppo_value_bootstrapping(self): @@ -131,7 +138,7 @@ def test_ppo_value_bootstrapping(self): lr=2e-4, lambda_=0.0, # Zero means pure value bootstrapping gamma=0.9, - train_batch_size=128, + train_batch_size=256, ) ) @@ -171,7 +178,8 @@ def test_ppo_value_bootstrapping(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py b/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py index 
10663dc932b8..0719b81cd3a9 100644 --- a/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py +++ b/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py @@ -146,10 +146,11 @@ def __call__( len_ + 1, sa_episode, ) - terminateds = ( - [False for _ in range(len_)] # Avoid ignoring last-step rewards when lambda=0 - + [bool(sa_episode.is_terminated)] # Use computed value for truncated eps. - ) # extra timestep + terminateds = [ + False for _ in range(len_) + ] + [ # Avoid ignoring last-step rewards when lambda=0 + bool(sa_episode.is_terminated) + ] # Use computed value for truncated eps. # extra timestep self.add_n_batch_items( batch, Columns.TERMINATEDS, From f54339c20431f9f6e81c569a94f39f37ff5ea0cd Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 7 Jan 2026 20:17:14 -0600 Subject: [PATCH 5/6] Fixed floating point equality assert. Signed-off-by: Matthew --- rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py b/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py index e4394c1b6919..66dde8a830d0 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py +++ b/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py @@ -93,7 +93,7 @@ def test_value_computation(self): ) for pred, gt in zip(two_term, correct): if gt is not None: - self.assertEqual(pred, gt) + self.assertAlmostEqual(pred, gt) # Test case where an episode is truncated (state value should be included) correct = [0.9405, 1.0, None, 0.9405, 1.9405, None] term_trunc = simulate_vt_calculation( @@ -106,7 +106,7 @@ def test_value_computation(self): ) for pred, gt in zip(term_trunc, correct): if gt is not None: - self.assertEqual(pred, gt) + self.assertAlmostEqual(pred, gt) def test_ppo_value_bootstrapping(self): """Test whether PPO's value bootstrapping works properly.""" From 07bbf2455b171270e152c3545adb4334e7c262d5 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 7 Jan 2026 20:40:37 -0600 Subject: [PATCH 6/6] Fix episode length calculation Signed-off-by: Matthew --- rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py b/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py index 66dde8a830d0..8267119e5b7c 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py +++ b/rllib/algorithms/ppo/tests/test_ppo_value_bootstrapping.py @@ -34,7 +34,6 @@ def simulate_vt_calculation(vfps, rewards, terminateds, truncateds, gamma, lambd len_lookback_buffer=0, ) ) - episode_lens = [len(e) for e in episodes] # Call AddOneTsToEpisodesAndTruncate pipe = LearnerConnectorPipeline( connectors=[ @@ -50,6 +49,7 @@ def simulate_vt_calculation(vfps, rewards, terminateds, truncateds, gamma, lambd explore=False, shared_data={}, ) + episode_lens = [len(e) for e in episodes] # Add the last episode's terminated/truncated flags to `terminateds` and `truncateds` vfps = [v for vfpl in vfps for v in vfpl] # Compute the value targets
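
For context on the expected numbers in test_value_computation: with lambda_=0.0, the value target for each step reduces to its reward plus gamma times the value prediction for the next observation, except that the final step of a terminated episode contributes only its reward, while a truncated episode still bootstraps from the value of the observation reached after its last action. Below is a minimal, standalone sketch of that rule in plain Python; td0_targets is an illustrative helper written for this note, not RLlib's compute_value_targets.

# Sketch only (no RLlib imports): TD(0)-style targets that test_value_computation
# above expects when lambda_ == 0. `td0_targets` is a hypothetical helper, not
# RLlib's compute_value_targets.
def td0_targets(values, rewards, terminated, gamma):
    # `values` holds one prediction per observation, including the observation
    # reached after the final action (len(values) == len(rewards) + 1).
    targets = []
    for t, r in enumerate(rewards):
        if terminated and t == len(rewards) - 1:
            # A terminated episode's last step has no future value to bootstrap from.
            targets.append(r)
        else:
            # Otherwise bootstrap from the next value prediction; this branch also
            # covers the last step of a truncated episode.
            targets.append(r + gamma * values[t + 1])
    return targets

print(td0_targets([0.0, 0.95, 0.95], [0.0, 1.0], True, 0.99))   # ~[0.9405, 1.0]
print(td0_targets([0.0, 0.95, 0.95], [0.0, 1.0], False, 0.99))  # ~[0.9405, 1.9405]

Up to float rounding, these match the non-None entries of `correct` in the terminated and truncated test cases, which is why the connector must leave the episode's real final step unterminated only when the episode was truncated.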