-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Make PPO compatible with composite actions and log-probs #2665
Conversation
ghstack-source-id: cbdaf533a39aeea41e3fbcda4e9d95a116eabfe1 Pull Request resolved: #2665
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2665
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
In this PR, I propose to let PPO have series of actions defined in the in-keys (rather than a single one) to accomodate CompositeDistributions better. This PR requires pytorch/tensordict#1146 and pytorch/tensordict#1145 to be merged or checked out. Here is a demo: |
Cool! Just to understand a bit, how is this related to multiagent? I see in the example that you are using different agent groups, but the feature seems to be more suited for composite single-agent actions. In multiagent, the suggested way to do things was to create a different loss for each group. This is to avoid losses taking a list of dones, rewards, and actions and have to match them. I think this feature for me makes sense for composite actions within a single-agent or a single marl group (avoiding taking a list of rewards and dones). |
Also in the example you are using a single module to output actions for multiple groups. |
I don't have a strong feeling RE multiagent or not, the use case that was suggested to me here had a composite action space where each leaf was labelled "agent_x" log_prob = make_some_tensordict(...)
prev_log_prob = make_some_tensordict(...)
advantage = make_a_tensordict_or_a_tensor(...)
loss = (log_prob - prev_log_prob).exp().clamp(...) * advantage and your loss will be a tensordict itself. |
This makes sense for a composite action space yes. But in your PR i see you are also allowing lists of dones and rewards. This is a bit less trivial as it opens up to a bunch of compatibility usecases if you want to use this in MARL.
Supporting all these usecases might become a big headacke which is why I preferred to stick with one reward and done key per loss class. |
ghstack-source-id: f465f2017843904a510aa06768ced457df987e94 Pull Request resolved: #2665
ghstack-source-id: 3bcf7ebf9619f62d68979f85021a769796da0539 Pull Request resolved: #2665
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.5193s | 0.4397s | 2.2744 Ops/s | 2.2492 Ops/s | |
test_transformed | 0.6984s | 0.6235s | 1.6039 Ops/s | 1.5797 Ops/s | |
test_serial | 1.4379s | 1.3598s | 0.7354 Ops/s | 0.7229 Ops/s | |
test_parallel | 1.2845s | 1.2043s | 0.8304 Ops/s | 0.8252 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 87.4840μs | 30.1470μs | 33.1708 KOps/s | 33.2811 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 53.0080μs | 17.8293μs | 56.0876 KOps/s | 55.5655 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 70.6210μs | 17.1317μs | 58.3715 KOps/s | 58.0494 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 42.8200μs | 9.9118μs | 100.8894 KOps/s | 99.9583 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 75.2100μs | 32.5899μs | 30.6844 KOps/s | 30.8499 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 56.3040μs | 19.6751μs | 50.8256 KOps/s | 50.4029 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 45.2340μs | 19.0236μs | 52.5662 KOps/s | 51.8576 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 46.0260μs | 11.8857μs | 84.1349 KOps/s | 83.0092 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 72.7360μs | 34.2120μs | 29.2295 KOps/s | 29.1574 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 62.8870μs | 21.6325μs | 46.2268 KOps/s | 45.9487 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 55.4830μs | 19.0078μs | 52.6099 KOps/s | 51.7589 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 38.6920μs | 11.7923μs | 84.8008 KOps/s | 83.7091 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 97.6920μs | 36.0825μs | 27.7143 KOps/s | 27.6847 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 66.2330μs | 23.2766μs | 42.9616 KOps/s | 42.7444 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 50.5340μs | 20.8181μs | 48.0351 KOps/s | 47.5910 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 55.4530μs | 13.6837μs | 73.0795 KOps/s | 72.8897 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 78.2860μs | 34.1097μs | 29.3171 KOps/s | 29.0625 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 75.2100μs | 21.7391μs | 46.0001 KOps/s | 45.1439 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 49.5420μs | 21.7326μs | 46.0138 KOps/s | 45.1289 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 46.7570μs | 13.2892μs | 75.2491 KOps/s | 74.8788 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 82.1130μs | 35.7234μs | 27.9928 KOps/s | 27.5925 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 51.4350μs | 23.3431μs | 42.8392 KOps/s | 42.1475 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 2.8327ms | 23.5063μs | 42.5418 KOps/s | 42.0787 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 49.5830μs | 15.0373μs | 66.5011 KOps/s | 64.7155 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 77.3840μs | 37.6136μs | 26.5862 KOps/s | 25.9647 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 53.6200μs | 25.1379μs | 39.7805 KOps/s | 39.1900 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 64.1990μs | 23.6581μs | 42.2688 KOps/s | 42.3753 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 47.0670μs | 14.8414μs | 67.3792 KOps/s | 65.3115 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 91.0000μs | 38.8681μs | 25.7280 KOps/s | 25.2006 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 84.9600μs | 26.6176μs | 37.5692 KOps/s | 36.8096 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 55.7240μs | 24.7743μs | 40.3644 KOps/s | 39.3064 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 45.9160μs | 16.4763μs | 60.6933 KOps/s | 59.0454 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 9.8481ms | 9.7179ms | 102.9028 Ops/s | 100.6715 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 35.6738ms | 33.4106ms | 29.9306 Ops/s | 30.1029 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2298ms | 0.1821ms | 5.4916 KOps/s | 5.6510 KOps/s | |
test_values[td1_return_estimate-False-False] | 24.6340ms | 23.9454ms | 41.7616 Ops/s | 40.7756 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 35.8642ms | 33.4868ms | 29.8626 Ops/s | 29.9635 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 38.4863ms | 34.9209ms | 28.6362 Ops/s | 28.1891 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 36.4310ms | 33.5669ms | 29.7912 Ops/s | 29.9787 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 13.8849ms | 8.5049ms | 117.5788 Ops/s | 116.1560 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.2590ms | 1.9230ms | 520.0254 Ops/s | 464.2907 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.4541ms | 0.3510ms | 2.8486 KOps/s | 2.7775 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 41.8709ms | 40.1590ms | 24.9010 Ops/s | 20.4377 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.6618ms | 3.0495ms | 327.9245 Ops/s | 329.3623 Ops/s | |
test_dqn_speed[False-None] | 1.6621ms | 1.4006ms | 713.9953 Ops/s | 719.5813 Ops/s | |
test_dqn_speed[False-backward] | 2.0187ms | 1.8918ms | 528.6026 Ops/s | 534.0775 Ops/s | |
test_dqn_speed[True-None] | 0.7771ms | 0.4772ms | 2.0954 KOps/s | 2.0571 KOps/s | |
test_dqn_speed[True-backward] | 0.9465ms | 0.8844ms | 1.1307 KOps/s | 1.0700 KOps/s | |
test_dqn_speed[reduce-overhead-None] | 0.8417ms | 0.4814ms | 2.0772 KOps/s | 2.0525 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.0523ms | 0.9174ms | 1.0900 KOps/s | 1.0998 KOps/s | |
test_ddpg_speed[False-None] | 3.6548ms | 2.9278ms | 341.5487 Ops/s | 344.1526 Ops/s | |
test_ddpg_speed[False-backward] | 4.2427ms | 4.0480ms | 247.0350 Ops/s | 245.4594 Ops/s | |
test_ddpg_speed[True-None] | 1.3556ms | 1.0104ms | 989.6710 Ops/s | 977.0736 Ops/s | |
test_ddpg_speed[True-backward] | 1.9697ms | 1.9026ms | 525.5960 Ops/s | 512.1032 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.4759ms | 1.0174ms | 982.8820 Ops/s | 981.5225 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.0104ms | 1.9035ms | 525.3491 Ops/s | 521.4086 Ops/s | |
test_sac_speed[False-None] | 8.5637ms | 8.0754ms | 123.8321 Ops/s | 122.4525 Ops/s | |
test_sac_speed[False-backward] | 11.1327ms | 10.8508ms | 92.1589 Ops/s | 91.9501 Ops/s | |
test_sac_speed[True-None] | 2.1802ms | 1.8416ms | 542.9932 Ops/s | 530.1901 Ops/s | |
test_sac_speed[True-backward] | 3.6599ms | 3.5242ms | 283.7559 Ops/s | 281.9940 Ops/s | |
test_sac_speed[reduce-overhead-None] | 2.4618ms | 1.8712ms | 534.4162 Ops/s | 529.4338 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 3.6042ms | 3.5502ms | 281.6731 Ops/s | 282.8909 Ops/s | |
test_redq_speed[False-None] | 14.9348ms | 13.2385ms | 75.5372 Ops/s | 74.9400 Ops/s | |
test_redq_speed[False-backward] | 0.2509s | 26.8326ms | 37.2681 Ops/s | 44.2988 Ops/s | |
test_redq_speed[True-None] | 5.1720ms | 4.6019ms | 217.3037 Ops/s | 214.0695 Ops/s | |
test_redq_speed[True-backward] | 14.2541ms | 12.1087ms | 82.5856 Ops/s | 82.2371 Ops/s | |
test_redq_speed[reduce-overhead-None] | 5.3655ms | 4.6408ms | 215.4824 Ops/s | 212.1425 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 12.8127ms | 12.2883ms | 81.3780 Ops/s | 79.9708 Ops/s | |
test_redq_deprec_speed[False-None] | 19.7512ms | 13.3500ms | 74.9066 Ops/s | 75.2655 Ops/s | |
test_redq_deprec_speed[False-backward] | 20.4256ms | 19.0139ms | 52.5930 Ops/s | 51.7434 Ops/s | |
test_redq_deprec_speed[True-None] | 4.2193ms | 3.6329ms | 275.2659 Ops/s | 276.7728 Ops/s | |
test_redq_deprec_speed[True-backward] | 8.3433ms | 8.0288ms | 124.5523 Ops/s | 123.9137 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 4.2952ms | 3.6522ms | 273.8093 Ops/s | 273.8830 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 9.1717ms | 8.1553ms | 122.6191 Ops/s | 122.2802 Ops/s | |
test_td3_speed[False-None] | 8.7342ms | 8.1205ms | 123.1455 Ops/s | 122.7459 Ops/s | |
test_td3_speed[False-backward] | 11.6681ms | 10.5270ms | 94.9939 Ops/s | 94.8784 Ops/s | |
test_td3_speed[True-None] | 2.0318ms | 1.7515ms | 570.9391 Ops/s | 569.0768 Ops/s | |
test_td3_speed[True-backward] | 3.4821ms | 3.3692ms | 296.8045 Ops/s | 299.8141 Ops/s | |
test_td3_speed[reduce-overhead-None] | 1.9615ms | 1.7499ms | 571.4567 Ops/s | 571.6955 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 4.1410ms | 3.4619ms | 288.8595 Ops/s | 300.6578 Ops/s | |
test_cql_speed[False-None] | 40.8806ms | 37.8429ms | 26.4250 Ops/s | 27.0813 Ops/s | |
test_cql_speed[False-backward] | 50.7469ms | 47.8195ms | 20.9120 Ops/s | 21.5597 Ops/s | |
test_cql_speed[True-None] | 17.1675ms | 15.7530ms | 63.4799 Ops/s | 63.2282 Ops/s | |
test_cql_speed[True-backward] | 23.6932ms | 22.5220ms | 44.4010 Ops/s | 44.2495 Ops/s | |
test_cql_speed[reduce-overhead-None] | 17.5538ms | 15.7629ms | 63.4400 Ops/s | 63.1020 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 23.7503ms | 22.2802ms | 44.8828 Ops/s | 43.9638 Ops/s | |
test_a2c_speed[False-None] | 8.9352ms | 7.2059ms | 138.7746 Ops/s | 136.4438 Ops/s | |
test_a2c_speed[False-backward] | 15.5531ms | 14.4464ms | 69.2212 Ops/s | 68.9459 Ops/s | |
test_a2c_speed[True-None] | 4.6579ms | 4.1878ms | 238.7907 Ops/s | 234.9764 Ops/s | |
test_a2c_speed[True-backward] | 12.4880ms | 10.8452ms | 92.2069 Ops/s | 92.2449 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 4.9120ms | 4.2075ms | 237.6688 Ops/s | 235.0268 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 11.2005ms | 10.7073ms | 93.3938 Ops/s | 91.9574 Ops/s | |
test_ppo_speed[False-None] | 8.6410ms | 7.4656ms | 133.9477 Ops/s | 131.5768 Ops/s | |
test_ppo_speed[False-backward] | 14.9632ms | 14.6843ms | 68.0999 Ops/s | 66.8186 Ops/s | |
test_ppo_speed[True-None] | 4.3863ms | 3.7296ms | 268.1263 Ops/s | 267.5815 Ops/s | |
test_ppo_speed[True-backward] | 10.3318ms | 9.6787ms | 103.3197 Ops/s | 102.5094 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 4.2392ms | 3.7079ms | 269.6949 Ops/s | 268.4210 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 9.8804ms | 9.5843ms | 104.3370 Ops/s | 103.9778 Ops/s | |
test_reinforce_speed[False-None] | 10.9536ms | 6.7409ms | 148.3476 Ops/s | 151.5232 Ops/s | |
test_reinforce_speed[False-backward] | 10.9703ms | 9.7696ms | 102.3585 Ops/s | 100.3097 Ops/s | |
test_reinforce_speed[True-None] | 3.2756ms | 2.6396ms | 378.8473 Ops/s | 371.4139 Ops/s | |
test_reinforce_speed[True-backward] | 9.0295ms | 8.6263ms | 115.9247 Ops/s | 116.3064 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 3.0140ms | 2.6441ms | 378.2053 Ops/s | 374.5602 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 9.8389ms | 8.6163ms | 116.0596 Ops/s | 116.2857 Ops/s | |
test_iql_speed[False-None] | 37.5598ms | 32.4233ms | 30.8420 Ops/s | 30.2376 Ops/s | |
test_iql_speed[False-backward] | 53.0605ms | 45.3225ms | 22.0641 Ops/s | 15.4965 Ops/s | |
test_iql_speed[True-None] | 14.0244ms | 10.8589ms | 92.0900 Ops/s | 89.3380 Ops/s | |
test_iql_speed[True-backward] | 23.8541ms | 21.5798ms | 46.3396 Ops/s | 45.6008 Ops/s | |
test_iql_speed[reduce-overhead-None] | 11.7460ms | 10.7661ms | 92.8842 Ops/s | 91.7004 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 23.8647ms | 21.8442ms | 45.7787 Ops/s | 45.7166 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.6955ms | 4.8777ms | 205.0166 Ops/s | 205.2409 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.9368ms | 0.5164ms | 1.9364 KOps/s | 1.8953 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 1.0172ms | 0.5008ms | 1.9969 KOps/s | 2.0046 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 7.4391ms | 4.6394ms | 215.5431 Ops/s | 212.7459 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.3259ms | 0.5048ms | 1.9808 KOps/s | 1.9740 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7006ms | 0.4786ms | 2.0892 KOps/s | 2.0498 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.3399ms | 1.6310ms | 613.1238 Ops/s | 597.4463 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.0825ms | 1.5447ms | 647.3874 Ops/s | 634.7722 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 7.4440ms | 4.7491ms | 210.5668 Ops/s | 205.9031 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.0852ms | 0.6436ms | 1.5539 KOps/s | 1.5464 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 1.0297ms | 0.6257ms | 1.5983 KOps/s | 1.5899 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.3450ms | 4.6439ms | 215.3383 Ops/s | 212.4700 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.1018ms | 0.5620ms | 1.7793 KOps/s | 1.9507 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.4065s | 1.0785ms | 927.2255 Ops/s | 1.9847 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.1262ms | 4.6336ms | 215.8168 Ops/s | 215.1799 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 3.5758ms | 0.5049ms | 1.9807 KOps/s | 524.2039 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7322ms | 0.4795ms | 2.0853 KOps/s | 2.0437 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 7.4217ms | 4.8125ms | 207.7916 Ops/s | 205.6183 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.6647ms | 0.6496ms | 1.5394 KOps/s | 1.5024 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.9567ms | 0.6219ms | 1.6080 KOps/s | 1.4775 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 6.2478ms | 4.1525ms | 240.8216 Ops/s | 223.6438 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 3.2856ms | 2.1231ms | 471.0141 Ops/s | 429.0480 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 6.4581ms | 1.3551ms | 737.9382 Ops/s | 792.0828 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.3780s | 11.6619ms | 85.7494 Ops/s | 242.0446 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 7.7008ms | 2.4059ms | 415.6414 Ops/s | 449.3406 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 5.7068ms | 1.3927ms | 718.0229 Ops/s | 737.3645 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 5.6341ms | 4.2780ms | 233.7535 Ops/s | 231.3539 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 7.9930ms | 2.5038ms | 399.3852 Ops/s | 416.0047 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 4.2527ms | 1.4904ms | 670.9399 Ops/s | 675.9410 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 13.3998ms | 12.9218ms | 77.3887 Ops/s | 74.3654 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 17.1468ms | 14.7463ms | 67.8135 Ops/s | 67.0262 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 22.1954ms | 21.5800ms | 46.3391 Ops/s | 45.0617 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 16.3584ms | 14.9057ms | 67.0884 Ops/s | 66.2170 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 23.9239ms | 21.6505ms | 46.1883 Ops/s | 45.2855 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 16.9953ms | 16.0335ms | 62.3695 Ops/s | 60.9682 Ops/s |
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.8215s | 0.7348s | 1.3609 Ops/s | 1.3611 Ops/s | |
test_transformed | 0.9649s | 0.9614s | 1.0401 Ops/s | 1.0210 Ops/s | |
test_serial | 2.1295s | 2.1169s | 0.4724 Ops/s | 0.4720 Ops/s | |
test_parallel | 1.8754s | 1.8188s | 0.5498 Ops/s | 0.5493 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1809ms | 40.5043μs | 24.6887 KOps/s | 25.0159 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 50.9010μs | 23.7547μs | 42.0969 KOps/s | 41.8017 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 66.9310μs | 22.3454μs | 44.7519 KOps/s | 44.6603 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 55.0410μs | 13.0297μs | 76.7476 KOps/s | 76.7184 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 89.1720μs | 42.4441μs | 23.5604 KOps/s | 23.2476 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 66.9710μs | 25.8497μs | 38.6852 KOps/s | 39.5432 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 56.1910μs | 24.7552μs | 40.3956 KOps/s | 40.5930 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 43.0110μs | 15.3643μs | 65.0858 KOps/s | 64.2076 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 0.1221ms | 44.8019μs | 22.3205 KOps/s | 21.9907 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 72.3720μs | 27.4130μs | 36.4790 KOps/s | 35.4472 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 72.5810μs | 24.0601μs | 41.5626 KOps/s | 40.7677 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 55.4210μs | 15.0420μs | 66.4804 KOps/s | 65.6629 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 90.7010μs | 46.9273μs | 21.3096 KOps/s | 21.7906 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 75.5420μs | 30.0228μs | 33.3080 KOps/s | 33.3014 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 62.3210μs | 26.2708μs | 38.0650 KOps/s | 37.5714 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 59.5310μs | 17.3472μs | 57.6462 KOps/s | 56.1313 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 0.1220ms | 44.1401μs | 22.6551 KOps/s | 22.1315 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 69.4010μs | 27.6173μs | 36.2092 KOps/s | 35.6733 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 80.1610μs | 28.2136μs | 35.4439 KOps/s | 35.1412 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 54.6310μs | 16.9872μs | 58.8679 KOps/s | 58.5069 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 0.1098ms | 45.4697μs | 21.9927 KOps/s | 21.1590 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 61.0510μs | 30.0930μs | 33.2303 KOps/s | 33.1133 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 3.1778ms | 31.3466μs | 31.9014 KOps/s | 33.0903 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 56.0610μs | 19.4079μs | 51.5255 KOps/s | 51.4623 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 91.4320μs | 49.3052μs | 20.2818 KOps/s | 19.9894 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 72.8910μs | 32.4634μs | 30.8039 KOps/s | 30.4458 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 74.3110μs | 30.5627μs | 32.7196 KOps/s | 32.8381 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 53.3410μs | 19.1799μs | 52.1378 KOps/s | 51.8123 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 89.6620μs | 50.8874μs | 19.6512 KOps/s | 19.3276 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 70.0210μs | 35.1690μs | 28.4341 KOps/s | 28.7889 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 0.1057ms | 32.1116μs | 31.1414 KOps/s | 30.9265 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 57.3410μs | 21.2477μs | 47.0638 KOps/s | 45.9000 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 24.9588ms | 24.5390ms | 40.7515 Ops/s | 41.5518 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 0.1050s | 2.9977ms | 333.5946 Ops/s | 351.8042 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1127ms | 79.8548μs | 12.5227 KOps/s | 12.6890 KOps/s | |
test_values[td1_return_estimate-False-False] | 55.3105ms | 54.8497ms | 18.2316 Ops/s | 18.4943 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.2822ms | 1.0817ms | 924.5026 Ops/s | 928.7876 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 87.7181ms | 87.1651ms | 11.4725 Ops/s | 11.6767 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.2482ms | 1.0752ms | 930.0881 Ops/s | 934.6611 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 26.1867ms | 24.4230ms | 40.9450 Ops/s | 42.0037 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0218ms | 0.7496ms | 1.3340 KOps/s | 1.3450 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.7908ms | 0.6693ms | 1.4942 KOps/s | 1.5011 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.5170ms | 1.4765ms | 677.2561 Ops/s | 680.2476 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.7291ms | 0.6834ms | 1.4633 KOps/s | 1.4755 KOps/s | |
test_dqn_speed[False-None] | 6.8569ms | 1.5392ms | 649.6953 Ops/s | 651.7939 Ops/s | |
test_dqn_speed[False-backward] | 2.4358ms | 2.1314ms | 469.1753 Ops/s | 468.1577 Ops/s | |
test_dqn_speed[True-None] | 0.9635ms | 0.5397ms | 1.8530 KOps/s | 1.7623 KOps/s | |
test_dqn_speed[True-backward] | 1.1830ms | 1.0915ms | 916.1532 Ops/s | 889.5191 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.9799ms | 0.5556ms | 1.7997 KOps/s | 1.6661 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.0487ms | 0.9466ms | 1.0565 KOps/s | 1.0037 KOps/s | |
test_ddpg_speed[False-None] | 3.2866ms | 2.8551ms | 350.2448 Ops/s | 342.9420 Ops/s | |
test_ddpg_speed[False-backward] | 4.2741ms | 4.1044ms | 243.6400 Ops/s | 238.3552 Ops/s | |
test_ddpg_speed[True-None] | 1.5120ms | 1.0642ms | 939.6704 Ops/s | 884.4533 Ops/s | |
test_ddpg_speed[True-backward] | 2.2201ms | 2.1345ms | 468.4870 Ops/s | 461.1894 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.2221ms | 1.0834ms | 922.9900 Ops/s | 916.4068 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.7125ms | 1.6170ms | 618.4324 Ops/s | 609.2124 Ops/s | |
test_sac_speed[False-None] | 8.4740ms | 8.0327ms | 124.4919 Ops/s | 124.4513 Ops/s | |
test_sac_speed[False-backward] | 11.4794ms | 10.9398ms | 91.4093 Ops/s | 90.4087 Ops/s | |
test_sac_speed[True-None] | 1.9402ms | 1.5222ms | 656.9417 Ops/s | 650.2521 Ops/s | |
test_sac_speed[True-backward] | 3.2575ms | 3.2019ms | 312.3148 Ops/s | 295.3007 Ops/s | |
test_sac_speed[reduce-overhead-None] | 22.9558ms | 12.7764ms | 78.2691 Ops/s | 79.2941 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 1.3930ms | 1.3361ms | 748.4616 Ops/s | 657.0597 Ops/s | |
test_redq_speed[False-None] | 8.2254ms | 7.4865ms | 133.5746 Ops/s | 132.1428 Ops/s | |
test_redq_speed[False-backward] | 12.1110ms | 11.2774ms | 88.6731 Ops/s | 85.7011 Ops/s | |
test_redq_speed[True-None] | 2.1427ms | 1.9779ms | 505.5881 Ops/s | 490.8505 Ops/s | |
test_redq_speed[True-backward] | 3.7529ms | 3.6327ms | 275.2807 Ops/s | 261.7281 Ops/s | |
test_redq_speed[reduce-overhead-None] | 2.1649ms | 2.0734ms | 482.2958 Ops/s | 485.4921 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 4.1424ms | 3.7375ms | 267.5557 Ops/s | 276.8512 Ops/s | |
test_redq_deprec_speed[False-None] | 9.8489ms | 9.3911ms | 106.4840 Ops/s | 109.6089 Ops/s | |
test_redq_deprec_speed[False-backward] | 13.2209ms | 12.3666ms | 80.8628 Ops/s | 82.8562 Ops/s | |
test_redq_deprec_speed[True-None] | 2.4729ms | 2.3296ms | 429.2627 Ops/s | 429.4198 Ops/s | |
test_redq_deprec_speed[True-backward] | 4.2080ms | 4.0649ms | 246.0076 Ops/s | 253.2172 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 2.3954ms | 2.3325ms | 428.7204 Ops/s | 427.9175 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 4.4975ms | 3.9688ms | 251.9635 Ops/s | 239.1814 Ops/s | |
test_td3_speed[False-None] | 34.3608ms | 8.1834ms | 122.1979 Ops/s | 123.1196 Ops/s | |
test_td3_speed[False-backward] | 10.9389ms | 10.2750ms | 97.3236 Ops/s | 93.7800 Ops/s | |
test_td3_speed[True-None] | 1.6045ms | 1.5836ms | 631.4683 Ops/s | 635.6857 Ops/s | |
test_td3_speed[True-backward] | 3.1676ms | 3.1020ms | 322.3757 Ops/s | 295.1765 Ops/s | |
test_td3_speed[reduce-overhead-None] | 58.2632ms | 25.8424ms | 38.6961 Ops/s | 38.9572 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 1.3453ms | 1.2961ms | 771.5378 Ops/s | 688.7525 Ops/s | |
test_cql_speed[False-None] | 17.3484ms | 16.7818ms | 59.5882 Ops/s | 59.2992 Ops/s | |
test_cql_speed[False-backward] | 22.4929ms | 21.9949ms | 45.4650 Ops/s | 44.3754 Ops/s | |
test_cql_speed[True-None] | 2.9806ms | 2.9035ms | 344.4072 Ops/s | 342.6605 Ops/s | |
test_cql_speed[True-backward] | 5.3783ms | 5.0068ms | 199.7301 Ops/s | 196.5296 Ops/s | |
test_cql_speed[reduce-overhead-None] | 0.3557s | 14.4530ms | 69.1900 Ops/s | 75.5720 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 1.7583ms | 1.6928ms | 590.7249 Ops/s | 649.0314 Ops/s | |
test_a2c_speed[False-None] | 3.2891ms | 3.2044ms | 312.0695 Ops/s | 307.8417 Ops/s | |
test_a2c_speed[False-backward] | 6.9481ms | 6.4233ms | 155.6836 Ops/s | 161.7433 Ops/s | |
test_a2c_speed[True-None] | 1.1916ms | 1.0118ms | 988.3346 Ops/s | 974.3100 Ops/s | |
test_a2c_speed[True-backward] | 2.7607ms | 2.7180ms | 367.9142 Ops/s | 382.4393 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 21.5998ms | 11.6216ms | 86.0466 Ops/s | 90.0484 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 1.1668ms | 1.1114ms | 899.7749 Ops/s | 861.1543 Ops/s | |
test_ppo_speed[False-None] | 3.9177ms | 3.6906ms | 270.9583 Ops/s | 271.0536 Ops/s | |
test_ppo_speed[False-backward] | 7.5016ms | 7.0893ms | 141.0573 Ops/s | 138.8253 Ops/s | |
test_ppo_speed[True-None] | 0.9951ms | 0.9482ms | 1.0547 KOps/s | 1.0407 KOps/s | |
test_ppo_speed[True-backward] | 2.7235ms | 2.6798ms | 373.1690 Ops/s | 366.9783 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 0.5746ms | 0.5278ms | 1.8946 KOps/s | 68.9701 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 1.1691ms | 1.1120ms | 899.2818 Ops/s | 982.0026 Ops/s | |
test_reinforce_speed[False-None] | 2.4389ms | 2.2579ms | 442.8852 Ops/s | 434.1367 Ops/s | |
test_reinforce_speed[False-backward] | 3.8288ms | 3.3951ms | 294.5439 Ops/s | 293.9229 Ops/s | |
test_reinforce_speed[True-None] | 0.8662ms | 0.8222ms | 1.2162 KOps/s | 1.1569 KOps/s | |
test_reinforce_speed[True-backward] | 2.5577ms | 2.5134ms | 397.8683 Ops/s | 399.9110 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 0.2909s | 12.1019ms | 82.6314 Ops/s | 87.8813 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 1.2069ms | 1.1626ms | 860.1473 Ops/s | 967.7263 Ops/s | |
test_iql_speed[False-None] | 9.9283ms | 9.3367ms | 107.1041 Ops/s | 106.6767 Ops/s | |
test_iql_speed[False-backward] | 13.7263ms | 13.3251ms | 75.0463 Ops/s | 76.2022 Ops/s | |
test_iql_speed[True-None] | 2.1078ms | 1.7519ms | 570.8004 Ops/s | 570.2023 Ops/s | |
test_iql_speed[True-backward] | 4.6076ms | 4.3722ms | 228.7183 Ops/s | 234.0122 Ops/s | |
test_iql_speed[reduce-overhead-None] | 20.1468ms | 11.4850ms | 87.0701 Ops/s | 86.8130 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 1.6945ms | 1.6013ms | 624.4858 Ops/s | 614.3693 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 8.0018ms | 6.4002ms | 156.2442 Ops/s | 154.5468 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.5987ms | 0.3460ms | 2.8904 KOps/s | 2.7989 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5748ms | 0.3259ms | 3.0680 KOps/s | 2.9652 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.3846ms | 6.1745ms | 161.9572 Ops/s | 161.6285 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.9469ms | 0.3138ms | 3.1865 KOps/s | 3.0558 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5458ms | 0.2891ms | 3.4588 KOps/s | 4.1386 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.6056ms | 1.3648ms | 732.6984 Ops/s | 794.1349 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.5165ms | 1.2670ms | 789.2777 Ops/s | 851.0564 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.5062ms | 6.3532ms | 157.4000 Ops/s | 156.4437 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.0695ms | 0.5009ms | 1.9965 KOps/s | 2.3940 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6876ms | 0.4782ms | 2.0911 KOps/s | 2.3411 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.3868ms | 6.1673ms | 162.1457 Ops/s | 159.8224 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.9050ms | 0.2738ms | 3.6526 KOps/s | 3.4274 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.4811ms | 0.2537ms | 3.9420 KOps/s | 3.1579 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.2850ms | 6.0736ms | 164.6459 Ops/s | 162.4610 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6303ms | 0.2961ms | 3.3778 KOps/s | 3.0012 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6184ms | 0.2821ms | 3.5450 KOps/s | 3.2815 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.4353ms | 6.2405ms | 160.2442 Ops/s | 156.5279 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.0565ms | 0.4991ms | 2.0036 KOps/s | 2.2206 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.9401ms | 0.4795ms | 2.0857 KOps/s | 2.1624 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 7.0958ms | 5.4961ms | 181.9458 Ops/s | 184.3960 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 7.9453ms | 2.0559ms | 486.4165 Ops/s | 430.4125 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 7.9404ms | 1.2177ms | 821.1977 Ops/s | 879.1141 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 7.3434ms | 5.4313ms | 184.1192 Ops/s | 183.8354 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 8.3253ms | 2.0149ms | 496.3137 Ops/s | 473.7291 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 7.0066ms | 1.2308ms | 812.4764 Ops/s | 785.3183 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.4926s | 15.4457ms | 64.7431 Ops/s | 33.2816 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 10.8323ms | 2.2531ms | 443.8255 Ops/s | 455.5096 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 6.3288ms | 1.3502ms | 740.6333 Ops/s | 715.9294 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 15.7634ms | 15.4062ms | 64.9088 Ops/s | 64.0810 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 19.8999ms | 17.5595ms | 56.9492 Ops/s | 56.8124 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 20.0776ms | 19.5433ms | 51.1684 Ops/s | 49.8886 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 19.6548ms | 17.5397ms | 57.0135 Ops/s | 57.2406 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 19.9450ms | 19.5031ms | 51.2738 Ops/s | 50.4255 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 20.8772ms | 18.9445ms | 52.7859 Ops/s | 52.8257 Ops/s |
In its current version, this PR assumes that you can have a multihead action but reward / done are going to be tensors. What we do is that everytime we need to re-compute anything from the dist, we look at it and if it's a composite dist AND if you didn't explicitly ask to aggregate the log-probs, we get a tensordict of log-probs. From there, for PPOLoss and KL version, the change is quite trivial. For ClipPPOLoss, there's a bit of change in the logic: before, we were summing all the log-probs (or weights), then clamping, then multiplying by the advantage. Now, we first clamp each weight leaf, then sum and multiply. Hopefully that should be more mathematically accurate since but I'm happy to revert this if people yell at me! There is still something I would like to implement but it could be a bit bc-breaking so I'd rather get people's opinion on it: currently, we kind of assume that users have set the This is what we have now: Lines 513 to 523 in 86ab9b7
The way I'm thinking about this is to append a PR to this stack where:
cc @louisfaury |
Why can't you (1) read the logprob if there is and if not (2) recompute it if there are the original dist prams and if not (3) throw error? |
We could, it's mainly a matter of what is the "default" to me. |
nm, this is a bit more ambitious than I thought since VTrace requires the log-prob to be present and we want the advantage to be callable outside of the loss |
ghstack-source-id: c41718e697f9b6edda17d4ddb5bd6d41402b7c30 Pull Request resolved: #2665
Stack from ghstack (oldest at bottom):