Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Make PPO compatible with composite actions and log-probs #2665

Merged
merged 12 commits into from
Jan 16, 2025

Conversation

vmoens
Copy link
Contributor

@vmoens vmoens commented Dec 18, 2024

[ghstack-poisoned]
vmoens added a commit that referenced this pull request Dec 18, 2024
ghstack-source-id: cbdaf533a39aeea41e3fbcda4e9d95a116eabfe1
Pull Request resolved: #2665
Copy link

pytorch-bot bot commented Dec 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2665

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 18, 2024
@vmoens
Copy link
Contributor Author

vmoens commented Dec 18, 2024

In this PR, I propose to let PPO have series of actions defined in the in-keys (rather than a single one) to accomodate CompositeDistributions better.

This PR requires pytorch/tensordict#1146 and pytorch/tensordict#1145 to be merged or checked out.

Here is a demo:
https://gist.github.com/vmoens/46175764240dcbaf311af562b9e53294

cc @matteobettini

@matteobettini
Copy link
Contributor

Cool!

Just to understand a bit, how is this related to multiagent?

I see in the example that you are using different agent groups, but the feature seems to be more suited for composite single-agent actions.

In multiagent, the suggested way to do things was to create a different loss for each group. This is to avoid losses taking a list of dones, rewards, and actions and have to match them.

I think this feature for me makes sense for composite actions within a single-agent or a single marl group (avoiding taking a list of rewards and dones).

@matteobettini
Copy link
Contributor

Also in the example you are using a single module to output actions for multiple groups.
I think also here the way we suggest to do things is to process different groups in different modules, so that each module can go to its loss.

@vmoens
Copy link
Contributor Author

vmoens commented Dec 19, 2024

I don't have a strong feeling RE multiagent or not, the use case that was suggested to me here had a composite action space where each leaf was labelled "agent_x"
I guess that long term there's a version of this where you could have one loss for all, since tensordict now supports arithmetic ops you could perfectly do

log_prob = make_some_tensordict(...)
prev_log_prob = make_some_tensordict(...)
advantage = make_a_tensordict_or_a_tensor(...)
loss = (log_prob - prev_log_prob).exp().clamp(...) * advantage

and your loss will be a tensordict itself.
This would probably break now but I do think we could actually get this to work and simplify the code at the same time (that will require deprecating some default behaviours in CompositeDistribution in v0.9 as announed in tensordict)

@matteobettini
Copy link
Contributor

matteobettini commented Dec 19, 2024

This makes sense for a composite action space yes. But in your PR i see you are also allowing lists of dones and rewards.

This is a bit less trivial as it opens up to a bunch of compatibility usecases if you want to use this in MARL.
The done and reward keys might not be a one to one mapping to actions:

  • groups can have composite actions
  • reward and done could be partially or totally shared across groups (each in a different way possibly)

Supporting all these usecases might become a big headacke which is why I preferred to stick with one reward and done key per loss class.

[ghstack-poisoned]
vmoens added a commit that referenced this pull request Dec 20, 2024
ghstack-source-id: f465f2017843904a510aa06768ced457df987e94
Pull Request resolved: #2665
@vmoens vmoens added the enhancement New feature or request label Dec 20, 2024
[ghstack-poisoned]
vmoens added a commit that referenced this pull request Jan 9, 2025
ghstack-source-id: 3bcf7ebf9619f62d68979f85021a769796da0539
Pull Request resolved: #2665
Copy link

github-actions bot commented Jan 9, 2025

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 149. Improved: $\large\color{#35bf28}8$. Worsened: $\large\color{#d91a1a}6$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_simple 0.5193s 0.4397s 2.2744 Ops/s 2.2492 Ops/s $\color{#35bf28}+1.12\%$
test_transformed 0.6984s 0.6235s 1.6039 Ops/s 1.5797 Ops/s $\color{#35bf28}+1.53\%$
test_serial 1.4379s 1.3598s 0.7354 Ops/s 0.7229 Ops/s $\color{#35bf28}+1.73\%$
test_parallel 1.2845s 1.2043s 0.8304 Ops/s 0.8252 Ops/s $\color{#35bf28}+0.63\%$
test_step_mdp_speed[True-True-True-True-True] 87.4840μs 30.1470μs 33.1708 KOps/s 33.2811 KOps/s $\color{#d91a1a}-0.33\%$
test_step_mdp_speed[True-True-True-True-False] 53.0080μs 17.8293μs 56.0876 KOps/s 55.5655 KOps/s $\color{#35bf28}+0.94\%$
test_step_mdp_speed[True-True-True-False-True] 70.6210μs 17.1317μs 58.3715 KOps/s 58.0494 KOps/s $\color{#35bf28}+0.55\%$
test_step_mdp_speed[True-True-True-False-False] 42.8200μs 9.9118μs 100.8894 KOps/s 99.9583 KOps/s $\color{#35bf28}+0.93\%$
test_step_mdp_speed[True-True-False-True-True] 75.2100μs 32.5899μs 30.6844 KOps/s 30.8499 KOps/s $\color{#d91a1a}-0.54\%$
test_step_mdp_speed[True-True-False-True-False] 56.3040μs 19.6751μs 50.8256 KOps/s 50.4029 KOps/s $\color{#35bf28}+0.84\%$
test_step_mdp_speed[True-True-False-False-True] 45.2340μs 19.0236μs 52.5662 KOps/s 51.8576 KOps/s $\color{#35bf28}+1.37\%$
test_step_mdp_speed[True-True-False-False-False] 46.0260μs 11.8857μs 84.1349 KOps/s 83.0092 KOps/s $\color{#35bf28}+1.36\%$
test_step_mdp_speed[True-False-True-True-True] 72.7360μs 34.2120μs 29.2295 KOps/s 29.1574 KOps/s $\color{#35bf28}+0.25\%$
test_step_mdp_speed[True-False-True-True-False] 62.8870μs 21.6325μs 46.2268 KOps/s 45.9487 KOps/s $\color{#35bf28}+0.61\%$
test_step_mdp_speed[True-False-True-False-True] 55.4830μs 19.0078μs 52.6099 KOps/s 51.7589 KOps/s $\color{#35bf28}+1.64\%$
test_step_mdp_speed[True-False-True-False-False] 38.6920μs 11.7923μs 84.8008 KOps/s 83.7091 KOps/s $\color{#35bf28}+1.30\%$
test_step_mdp_speed[True-False-False-True-True] 97.6920μs 36.0825μs 27.7143 KOps/s 27.6847 KOps/s $\color{#35bf28}+0.11\%$
test_step_mdp_speed[True-False-False-True-False] 66.2330μs 23.2766μs 42.9616 KOps/s 42.7444 KOps/s $\color{#35bf28}+0.51\%$
test_step_mdp_speed[True-False-False-False-True] 50.5340μs 20.8181μs 48.0351 KOps/s 47.5910 KOps/s $\color{#35bf28}+0.93\%$
test_step_mdp_speed[True-False-False-False-False] 55.4530μs 13.6837μs 73.0795 KOps/s 72.8897 KOps/s $\color{#35bf28}+0.26\%$
test_step_mdp_speed[False-True-True-True-True] 78.2860μs 34.1097μs 29.3171 KOps/s 29.0625 KOps/s $\color{#35bf28}+0.88\%$
test_step_mdp_speed[False-True-True-True-False] 75.2100μs 21.7391μs 46.0001 KOps/s 45.1439 KOps/s $\color{#35bf28}+1.90\%$
test_step_mdp_speed[False-True-True-False-True] 49.5420μs 21.7326μs 46.0138 KOps/s 45.1289 KOps/s $\color{#35bf28}+1.96\%$
test_step_mdp_speed[False-True-True-False-False] 46.7570μs 13.2892μs 75.2491 KOps/s 74.8788 KOps/s $\color{#35bf28}+0.49\%$
test_step_mdp_speed[False-True-False-True-True] 82.1130μs 35.7234μs 27.9928 KOps/s 27.5925 KOps/s $\color{#35bf28}+1.45\%$
test_step_mdp_speed[False-True-False-True-False] 51.4350μs 23.3431μs 42.8392 KOps/s 42.1475 KOps/s $\color{#35bf28}+1.64\%$
test_step_mdp_speed[False-True-False-False-True] 2.8327ms 23.5063μs 42.5418 KOps/s 42.0787 KOps/s $\color{#35bf28}+1.10\%$
test_step_mdp_speed[False-True-False-False-False] 49.5830μs 15.0373μs 66.5011 KOps/s 64.7155 KOps/s $\color{#35bf28}+2.76\%$
test_step_mdp_speed[False-False-True-True-True] 77.3840μs 37.6136μs 26.5862 KOps/s 25.9647 KOps/s $\color{#35bf28}+2.39\%$
test_step_mdp_speed[False-False-True-True-False] 53.6200μs 25.1379μs 39.7805 KOps/s 39.1900 KOps/s $\color{#35bf28}+1.51\%$
test_step_mdp_speed[False-False-True-False-True] 64.1990μs 23.6581μs 42.2688 KOps/s 42.3753 KOps/s $\color{#d91a1a}-0.25\%$
test_step_mdp_speed[False-False-True-False-False] 47.0670μs 14.8414μs 67.3792 KOps/s 65.3115 KOps/s $\color{#35bf28}+3.17\%$
test_step_mdp_speed[False-False-False-True-True] 91.0000μs 38.8681μs 25.7280 KOps/s 25.2006 KOps/s $\color{#35bf28}+2.09\%$
test_step_mdp_speed[False-False-False-True-False] 84.9600μs 26.6176μs 37.5692 KOps/s 36.8096 KOps/s $\color{#35bf28}+2.06\%$
test_step_mdp_speed[False-False-False-False-True] 55.7240μs 24.7743μs 40.3644 KOps/s 39.3064 KOps/s $\color{#35bf28}+2.69\%$
test_step_mdp_speed[False-False-False-False-False] 45.9160μs 16.4763μs 60.6933 KOps/s 59.0454 KOps/s $\color{#35bf28}+2.79\%$
test_values[generalized_advantage_estimate-True-True] 9.8481ms 9.7179ms 102.9028 Ops/s 100.6715 Ops/s $\color{#35bf28}+2.22\%$
test_values[vec_generalized_advantage_estimate-True-True] 35.6738ms 33.4106ms 29.9306 Ops/s 30.1029 Ops/s $\color{#d91a1a}-0.57\%$
test_values[td0_return_estimate-False-False] 0.2298ms 0.1821ms 5.4916 KOps/s 5.6510 KOps/s $\color{#d91a1a}-2.82\%$
test_values[td1_return_estimate-False-False] 24.6340ms 23.9454ms 41.7616 Ops/s 40.7756 Ops/s $\color{#35bf28}+2.42\%$
test_values[vec_td1_return_estimate-False-False] 35.8642ms 33.4868ms 29.8626 Ops/s 29.9635 Ops/s $\color{#d91a1a}-0.34\%$
test_values[td_lambda_return_estimate-True-False] 38.4863ms 34.9209ms 28.6362 Ops/s 28.1891 Ops/s $\color{#35bf28}+1.59\%$
test_values[vec_td_lambda_return_estimate-True-False] 36.4310ms 33.5669ms 29.7912 Ops/s 29.9787 Ops/s $\color{#d91a1a}-0.63\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 13.8849ms 8.5049ms 117.5788 Ops/s 116.1560 Ops/s $\color{#35bf28}+1.22\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 2.2590ms 1.9230ms 520.0254 Ops/s 464.2907 Ops/s $\textbf{\color{#35bf28}+12.00\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.4541ms 0.3510ms 2.8486 KOps/s 2.7775 KOps/s $\color{#35bf28}+2.56\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 41.8709ms 40.1590ms 24.9010 Ops/s 20.4377 Ops/s $\textbf{\color{#35bf28}+21.84\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 4.6618ms 3.0495ms 327.9245 Ops/s 329.3623 Ops/s $\color{#d91a1a}-0.44\%$
test_dqn_speed[False-None] 1.6621ms 1.4006ms 713.9953 Ops/s 719.5813 Ops/s $\color{#d91a1a}-0.78\%$
test_dqn_speed[False-backward] 2.0187ms 1.8918ms 528.6026 Ops/s 534.0775 Ops/s $\color{#d91a1a}-1.03\%$
test_dqn_speed[True-None] 0.7771ms 0.4772ms 2.0954 KOps/s 2.0571 KOps/s $\color{#35bf28}+1.86\%$
test_dqn_speed[True-backward] 0.9465ms 0.8844ms 1.1307 KOps/s 1.0700 KOps/s $\textbf{\color{#35bf28}+5.68\%}$
test_dqn_speed[reduce-overhead-None] 0.8417ms 0.4814ms 2.0772 KOps/s 2.0525 KOps/s $\color{#35bf28}+1.20\%$
test_dqn_speed[reduce-overhead-backward] 1.0523ms 0.9174ms 1.0900 KOps/s 1.0998 KOps/s $\color{#d91a1a}-0.89\%$
test_ddpg_speed[False-None] 3.6548ms 2.9278ms 341.5487 Ops/s 344.1526 Ops/s $\color{#d91a1a}-0.76\%$
test_ddpg_speed[False-backward] 4.2427ms 4.0480ms 247.0350 Ops/s 245.4594 Ops/s $\color{#35bf28}+0.64\%$
test_ddpg_speed[True-None] 1.3556ms 1.0104ms 989.6710 Ops/s 977.0736 Ops/s $\color{#35bf28}+1.29\%$
test_ddpg_speed[True-backward] 1.9697ms 1.9026ms 525.5960 Ops/s 512.1032 Ops/s $\color{#35bf28}+2.63\%$
test_ddpg_speed[reduce-overhead-None] 1.4759ms 1.0174ms 982.8820 Ops/s 981.5225 Ops/s $\color{#35bf28}+0.14\%$
test_ddpg_speed[reduce-overhead-backward] 2.0104ms 1.9035ms 525.3491 Ops/s 521.4086 Ops/s $\color{#35bf28}+0.76\%$
test_sac_speed[False-None] 8.5637ms 8.0754ms 123.8321 Ops/s 122.4525 Ops/s $\color{#35bf28}+1.13\%$
test_sac_speed[False-backward] 11.1327ms 10.8508ms 92.1589 Ops/s 91.9501 Ops/s $\color{#35bf28}+0.23\%$
test_sac_speed[True-None] 2.1802ms 1.8416ms 542.9932 Ops/s 530.1901 Ops/s $\color{#35bf28}+2.41\%$
test_sac_speed[True-backward] 3.6599ms 3.5242ms 283.7559 Ops/s 281.9940 Ops/s $\color{#35bf28}+0.62\%$
test_sac_speed[reduce-overhead-None] 2.4618ms 1.8712ms 534.4162 Ops/s 529.4338 Ops/s $\color{#35bf28}+0.94\%$
test_sac_speed[reduce-overhead-backward] 3.6042ms 3.5502ms 281.6731 Ops/s 282.8909 Ops/s $\color{#d91a1a}-0.43\%$
test_redq_speed[False-None] 14.9348ms 13.2385ms 75.5372 Ops/s 74.9400 Ops/s $\color{#35bf28}+0.80\%$
test_redq_speed[False-backward] 0.2509s 26.8326ms 37.2681 Ops/s 44.2988 Ops/s $\textbf{\color{#d91a1a}-15.87\%}$
test_redq_speed[True-None] 5.1720ms 4.6019ms 217.3037 Ops/s 214.0695 Ops/s $\color{#35bf28}+1.51\%$
test_redq_speed[True-backward] 14.2541ms 12.1087ms 82.5856 Ops/s 82.2371 Ops/s $\color{#35bf28}+0.42\%$
test_redq_speed[reduce-overhead-None] 5.3655ms 4.6408ms 215.4824 Ops/s 212.1425 Ops/s $\color{#35bf28}+1.57\%$
test_redq_speed[reduce-overhead-backward] 12.8127ms 12.2883ms 81.3780 Ops/s 79.9708 Ops/s $\color{#35bf28}+1.76\%$
test_redq_deprec_speed[False-None] 19.7512ms 13.3500ms 74.9066 Ops/s 75.2655 Ops/s $\color{#d91a1a}-0.48\%$
test_redq_deprec_speed[False-backward] 20.4256ms 19.0139ms 52.5930 Ops/s 51.7434 Ops/s $\color{#35bf28}+1.64\%$
test_redq_deprec_speed[True-None] 4.2193ms 3.6329ms 275.2659 Ops/s 276.7728 Ops/s $\color{#d91a1a}-0.54\%$
test_redq_deprec_speed[True-backward] 8.3433ms 8.0288ms 124.5523 Ops/s 123.9137 Ops/s $\color{#35bf28}+0.52\%$
test_redq_deprec_speed[reduce-overhead-None] 4.2952ms 3.6522ms 273.8093 Ops/s 273.8830 Ops/s $\color{#d91a1a}-0.03\%$
test_redq_deprec_speed[reduce-overhead-backward] 9.1717ms 8.1553ms 122.6191 Ops/s 122.2802 Ops/s $\color{#35bf28}+0.28\%$
test_td3_speed[False-None] 8.7342ms 8.1205ms 123.1455 Ops/s 122.7459 Ops/s $\color{#35bf28}+0.33\%$
test_td3_speed[False-backward] 11.6681ms 10.5270ms 94.9939 Ops/s 94.8784 Ops/s $\color{#35bf28}+0.12\%$
test_td3_speed[True-None] 2.0318ms 1.7515ms 570.9391 Ops/s 569.0768 Ops/s $\color{#35bf28}+0.33\%$
test_td3_speed[True-backward] 3.4821ms 3.3692ms 296.8045 Ops/s 299.8141 Ops/s $\color{#d91a1a}-1.00\%$
test_td3_speed[reduce-overhead-None] 1.9615ms 1.7499ms 571.4567 Ops/s 571.6955 Ops/s $\color{#d91a1a}-0.04\%$
test_td3_speed[reduce-overhead-backward] 4.1410ms 3.4619ms 288.8595 Ops/s 300.6578 Ops/s $\color{#d91a1a}-3.92\%$
test_cql_speed[False-None] 40.8806ms 37.8429ms 26.4250 Ops/s 27.0813 Ops/s $\color{#d91a1a}-2.42\%$
test_cql_speed[False-backward] 50.7469ms 47.8195ms 20.9120 Ops/s 21.5597 Ops/s $\color{#d91a1a}-3.00\%$
test_cql_speed[True-None] 17.1675ms 15.7530ms 63.4799 Ops/s 63.2282 Ops/s $\color{#35bf28}+0.40\%$
test_cql_speed[True-backward] 23.6932ms 22.5220ms 44.4010 Ops/s 44.2495 Ops/s $\color{#35bf28}+0.34\%$
test_cql_speed[reduce-overhead-None] 17.5538ms 15.7629ms 63.4400 Ops/s 63.1020 Ops/s $\color{#35bf28}+0.54\%$
test_cql_speed[reduce-overhead-backward] 23.7503ms 22.2802ms 44.8828 Ops/s 43.9638 Ops/s $\color{#35bf28}+2.09\%$
test_a2c_speed[False-None] 8.9352ms 7.2059ms 138.7746 Ops/s 136.4438 Ops/s $\color{#35bf28}+1.71\%$
test_a2c_speed[False-backward] 15.5531ms 14.4464ms 69.2212 Ops/s 68.9459 Ops/s $\color{#35bf28}+0.40\%$
test_a2c_speed[True-None] 4.6579ms 4.1878ms 238.7907 Ops/s 234.9764 Ops/s $\color{#35bf28}+1.62\%$
test_a2c_speed[True-backward] 12.4880ms 10.8452ms 92.2069 Ops/s 92.2449 Ops/s $\color{#d91a1a}-0.04\%$
test_a2c_speed[reduce-overhead-None] 4.9120ms 4.2075ms 237.6688 Ops/s 235.0268 Ops/s $\color{#35bf28}+1.12\%$
test_a2c_speed[reduce-overhead-backward] 11.2005ms 10.7073ms 93.3938 Ops/s 91.9574 Ops/s $\color{#35bf28}+1.56\%$
test_ppo_speed[False-None] 8.6410ms 7.4656ms 133.9477 Ops/s 131.5768 Ops/s $\color{#35bf28}+1.80\%$
test_ppo_speed[False-backward] 14.9632ms 14.6843ms 68.0999 Ops/s 66.8186 Ops/s $\color{#35bf28}+1.92\%$
test_ppo_speed[True-None] 4.3863ms 3.7296ms 268.1263 Ops/s 267.5815 Ops/s $\color{#35bf28}+0.20\%$
test_ppo_speed[True-backward] 10.3318ms 9.6787ms 103.3197 Ops/s 102.5094 Ops/s $\color{#35bf28}+0.79\%$
test_ppo_speed[reduce-overhead-None] 4.2392ms 3.7079ms 269.6949 Ops/s 268.4210 Ops/s $\color{#35bf28}+0.47\%$
test_ppo_speed[reduce-overhead-backward] 9.8804ms 9.5843ms 104.3370 Ops/s 103.9778 Ops/s $\color{#35bf28}+0.35\%$
test_reinforce_speed[False-None] 10.9536ms 6.7409ms 148.3476 Ops/s 151.5232 Ops/s $\color{#d91a1a}-2.10\%$
test_reinforce_speed[False-backward] 10.9703ms 9.7696ms 102.3585 Ops/s 100.3097 Ops/s $\color{#35bf28}+2.04\%$
test_reinforce_speed[True-None] 3.2756ms 2.6396ms 378.8473 Ops/s 371.4139 Ops/s $\color{#35bf28}+2.00\%$
test_reinforce_speed[True-backward] 9.0295ms 8.6263ms 115.9247 Ops/s 116.3064 Ops/s $\color{#d91a1a}-0.33\%$
test_reinforce_speed[reduce-overhead-None] 3.0140ms 2.6441ms 378.2053 Ops/s 374.5602 Ops/s $\color{#35bf28}+0.97\%$
test_reinforce_speed[reduce-overhead-backward] 9.8389ms 8.6163ms 116.0596 Ops/s 116.2857 Ops/s $\color{#d91a1a}-0.19\%$
test_iql_speed[False-None] 37.5598ms 32.4233ms 30.8420 Ops/s 30.2376 Ops/s $\color{#35bf28}+2.00\%$
test_iql_speed[False-backward] 53.0605ms 45.3225ms 22.0641 Ops/s 15.4965 Ops/s $\textbf{\color{#35bf28}+42.38\%}$
test_iql_speed[True-None] 14.0244ms 10.8589ms 92.0900 Ops/s 89.3380 Ops/s $\color{#35bf28}+3.08\%$
test_iql_speed[True-backward] 23.8541ms 21.5798ms 46.3396 Ops/s 45.6008 Ops/s $\color{#35bf28}+1.62\%$
test_iql_speed[reduce-overhead-None] 11.7460ms 10.7661ms 92.8842 Ops/s 91.7004 Ops/s $\color{#35bf28}+1.29\%$
test_iql_speed[reduce-overhead-backward] 23.8647ms 21.8442ms 45.7787 Ops/s 45.7166 Ops/s $\color{#35bf28}+0.14\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 7.6955ms 4.8777ms 205.0166 Ops/s 205.2409 Ops/s $\color{#d91a1a}-0.11\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.9368ms 0.5164ms 1.9364 KOps/s 1.8953 KOps/s $\color{#35bf28}+2.17\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 1.0172ms 0.5008ms 1.9969 KOps/s 2.0046 KOps/s $\color{#d91a1a}-0.38\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 7.4391ms 4.6394ms 215.5431 Ops/s 212.7459 Ops/s $\color{#35bf28}+1.31\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 1.3259ms 0.5048ms 1.9808 KOps/s 1.9740 KOps/s $\color{#35bf28}+0.35\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.7006ms 0.4786ms 2.0892 KOps/s 2.0498 KOps/s $\color{#35bf28}+1.93\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] 2.3399ms 1.6310ms 613.1238 Ops/s 597.4463 Ops/s $\color{#35bf28}+2.62\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] 2.0825ms 1.5447ms 647.3874 Ops/s 634.7722 Ops/s $\color{#35bf28}+1.99\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 7.4440ms 4.7491ms 210.5668 Ops/s 205.9031 Ops/s $\color{#35bf28}+2.26\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1.0852ms 0.6436ms 1.5539 KOps/s 1.5464 KOps/s $\color{#35bf28}+0.48\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 1.0297ms 0.6257ms 1.5983 KOps/s 1.5899 KOps/s $\color{#35bf28}+0.53\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 7.3450ms 4.6439ms 215.3383 Ops/s 212.4700 Ops/s $\color{#35bf28}+1.35\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 1.1018ms 0.5620ms 1.7793 KOps/s 1.9507 KOps/s $\textbf{\color{#d91a1a}-8.79\%}$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.4065s 1.0785ms 927.2255 Ops/s 1.9847 KOps/s $\textbf{\color{#d91a1a}-53.28\%}$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 5.1262ms 4.6336ms 215.8168 Ops/s 215.1799 Ops/s $\color{#35bf28}+0.30\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 3.5758ms 0.5049ms 1.9807 KOps/s 524.2039 Ops/s $\textbf{\color{#35bf28}+277.84\%}$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.7322ms 0.4795ms 2.0853 KOps/s 2.0437 KOps/s $\color{#35bf28}+2.04\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 7.4217ms 4.8125ms 207.7916 Ops/s 205.6183 Ops/s $\color{#35bf28}+1.06\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 2.6647ms 0.6496ms 1.5394 KOps/s 1.5024 KOps/s $\color{#35bf28}+2.46\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.9567ms 0.6219ms 1.6080 KOps/s 1.4775 KOps/s $\textbf{\color{#35bf28}+8.83\%}$
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 6.2478ms 4.1525ms 240.8216 Ops/s 223.6438 Ops/s $\textbf{\color{#35bf28}+7.68\%}$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 3.2856ms 2.1231ms 471.0141 Ops/s 429.0480 Ops/s $\textbf{\color{#35bf28}+9.78\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 6.4581ms 1.3551ms 737.9382 Ops/s 792.0828 Ops/s $\textbf{\color{#d91a1a}-6.84\%}$
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.3780s 11.6619ms 85.7494 Ops/s 242.0446 Ops/s $\textbf{\color{#d91a1a}-64.57\%}$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 7.7008ms 2.4059ms 415.6414 Ops/s 449.3406 Ops/s $\textbf{\color{#d91a1a}-7.50\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 5.7068ms 1.3927ms 718.0229 Ops/s 737.3645 Ops/s $\color{#d91a1a}-2.62\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 5.6341ms 4.2780ms 233.7535 Ops/s 231.3539 Ops/s $\color{#35bf28}+1.04\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 7.9930ms 2.5038ms 399.3852 Ops/s 416.0047 Ops/s $\color{#d91a1a}-4.00\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 4.2527ms 1.4904ms 670.9399 Ops/s 675.9410 Ops/s $\color{#d91a1a}-0.74\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] 13.3998ms 12.9218ms 77.3887 Ops/s 74.3654 Ops/s $\color{#35bf28}+4.07\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] 17.1468ms 14.7463ms 67.8135 Ops/s 67.0262 Ops/s $\color{#35bf28}+1.17\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] 22.1954ms 21.5800ms 46.3391 Ops/s 45.0617 Ops/s $\color{#35bf28}+2.83\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] 16.3584ms 14.9057ms 67.0884 Ops/s 66.2170 Ops/s $\color{#35bf28}+1.32\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] 23.9239ms 21.6505ms 46.1883 Ops/s 45.2855 Ops/s $\color{#35bf28}+1.99\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] 16.9953ms 16.0335ms 62.3695 Ops/s 60.9682 Ops/s $\color{#35bf28}+2.30\%$

Copy link

github-actions bot commented Jan 9, 2025

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of GPU Benchmark Tests

Total Benchmarks: 149. Improved: $\large\color{#35bf28}18$. Worsened: $\large\color{#d91a1a}13$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_simple 0.8215s 0.7348s 1.3609 Ops/s 1.3611 Ops/s $\color{#d91a1a}-0.02\%$
test_transformed 0.9649s 0.9614s 1.0401 Ops/s 1.0210 Ops/s $\color{#35bf28}+1.88\%$
test_serial 2.1295s 2.1169s 0.4724 Ops/s 0.4720 Ops/s $\color{#35bf28}+0.07\%$
test_parallel 1.8754s 1.8188s 0.5498 Ops/s 0.5493 Ops/s $\color{#35bf28}+0.10\%$
test_step_mdp_speed[True-True-True-True-True] 0.1809ms 40.5043μs 24.6887 KOps/s 25.0159 KOps/s $\color{#d91a1a}-1.31\%$
test_step_mdp_speed[True-True-True-True-False] 50.9010μs 23.7547μs 42.0969 KOps/s 41.8017 KOps/s $\color{#35bf28}+0.71\%$
test_step_mdp_speed[True-True-True-False-True] 66.9310μs 22.3454μs 44.7519 KOps/s 44.6603 KOps/s $\color{#35bf28}+0.21\%$
test_step_mdp_speed[True-True-True-False-False] 55.0410μs 13.0297μs 76.7476 KOps/s 76.7184 KOps/s $\color{#35bf28}+0.04\%$
test_step_mdp_speed[True-True-False-True-True] 89.1720μs 42.4441μs 23.5604 KOps/s 23.2476 KOps/s $\color{#35bf28}+1.35\%$
test_step_mdp_speed[True-True-False-True-False] 66.9710μs 25.8497μs 38.6852 KOps/s 39.5432 KOps/s $\color{#d91a1a}-2.17\%$
test_step_mdp_speed[True-True-False-False-True] 56.1910μs 24.7552μs 40.3956 KOps/s 40.5930 KOps/s $\color{#d91a1a}-0.49\%$
test_step_mdp_speed[True-True-False-False-False] 43.0110μs 15.3643μs 65.0858 KOps/s 64.2076 KOps/s $\color{#35bf28}+1.37\%$
test_step_mdp_speed[True-False-True-True-True] 0.1221ms 44.8019μs 22.3205 KOps/s 21.9907 KOps/s $\color{#35bf28}+1.50\%$
test_step_mdp_speed[True-False-True-True-False] 72.3720μs 27.4130μs 36.4790 KOps/s 35.4472 KOps/s $\color{#35bf28}+2.91\%$
test_step_mdp_speed[True-False-True-False-True] 72.5810μs 24.0601μs 41.5626 KOps/s 40.7677 KOps/s $\color{#35bf28}+1.95\%$
test_step_mdp_speed[True-False-True-False-False] 55.4210μs 15.0420μs 66.4804 KOps/s 65.6629 KOps/s $\color{#35bf28}+1.24\%$
test_step_mdp_speed[True-False-False-True-True] 90.7010μs 46.9273μs 21.3096 KOps/s 21.7906 KOps/s $\color{#d91a1a}-2.21\%$
test_step_mdp_speed[True-False-False-True-False] 75.5420μs 30.0228μs 33.3080 KOps/s 33.3014 KOps/s $\color{#35bf28}+0.02\%$
test_step_mdp_speed[True-False-False-False-True] 62.3210μs 26.2708μs 38.0650 KOps/s 37.5714 KOps/s $\color{#35bf28}+1.31\%$
test_step_mdp_speed[True-False-False-False-False] 59.5310μs 17.3472μs 57.6462 KOps/s 56.1313 KOps/s $\color{#35bf28}+2.70\%$
test_step_mdp_speed[False-True-True-True-True] 0.1220ms 44.1401μs 22.6551 KOps/s 22.1315 KOps/s $\color{#35bf28}+2.37\%$
test_step_mdp_speed[False-True-True-True-False] 69.4010μs 27.6173μs 36.2092 KOps/s 35.6733 KOps/s $\color{#35bf28}+1.50\%$
test_step_mdp_speed[False-True-True-False-True] 80.1610μs 28.2136μs 35.4439 KOps/s 35.1412 KOps/s $\color{#35bf28}+0.86\%$
test_step_mdp_speed[False-True-True-False-False] 54.6310μs 16.9872μs 58.8679 KOps/s 58.5069 KOps/s $\color{#35bf28}+0.62\%$
test_step_mdp_speed[False-True-False-True-True] 0.1098ms 45.4697μs 21.9927 KOps/s 21.1590 KOps/s $\color{#35bf28}+3.94\%$
test_step_mdp_speed[False-True-False-True-False] 61.0510μs 30.0930μs 33.2303 KOps/s 33.1133 KOps/s $\color{#35bf28}+0.35\%$
test_step_mdp_speed[False-True-False-False-True] 3.1778ms 31.3466μs 31.9014 KOps/s 33.0903 KOps/s $\color{#d91a1a}-3.59\%$
test_step_mdp_speed[False-True-False-False-False] 56.0610μs 19.4079μs 51.5255 KOps/s 51.4623 KOps/s $\color{#35bf28}+0.12\%$
test_step_mdp_speed[False-False-True-True-True] 91.4320μs 49.3052μs 20.2818 KOps/s 19.9894 KOps/s $\color{#35bf28}+1.46\%$
test_step_mdp_speed[False-False-True-True-False] 72.8910μs 32.4634μs 30.8039 KOps/s 30.4458 KOps/s $\color{#35bf28}+1.18\%$
test_step_mdp_speed[False-False-True-False-True] 74.3110μs 30.5627μs 32.7196 KOps/s 32.8381 KOps/s $\color{#d91a1a}-0.36\%$
test_step_mdp_speed[False-False-True-False-False] 53.3410μs 19.1799μs 52.1378 KOps/s 51.8123 KOps/s $\color{#35bf28}+0.63\%$
test_step_mdp_speed[False-False-False-True-True] 89.6620μs 50.8874μs 19.6512 KOps/s 19.3276 KOps/s $\color{#35bf28}+1.67\%$
test_step_mdp_speed[False-False-False-True-False] 70.0210μs 35.1690μs 28.4341 KOps/s 28.7889 KOps/s $\color{#d91a1a}-1.23\%$
test_step_mdp_speed[False-False-False-False-True] 0.1057ms 32.1116μs 31.1414 KOps/s 30.9265 KOps/s $\color{#35bf28}+0.69\%$
test_step_mdp_speed[False-False-False-False-False] 57.3410μs 21.2477μs 47.0638 KOps/s 45.9000 KOps/s $\color{#35bf28}+2.54\%$
test_values[generalized_advantage_estimate-True-True] 24.9588ms 24.5390ms 40.7515 Ops/s 41.5518 Ops/s $\color{#d91a1a}-1.93\%$
test_values[vec_generalized_advantage_estimate-True-True] 0.1050s 2.9977ms 333.5946 Ops/s 351.8042 Ops/s $\textbf{\color{#d91a1a}-5.18\%}$
test_values[td0_return_estimate-False-False] 0.1127ms 79.8548μs 12.5227 KOps/s 12.6890 KOps/s $\color{#d91a1a}-1.31\%$
test_values[td1_return_estimate-False-False] 55.3105ms 54.8497ms 18.2316 Ops/s 18.4943 Ops/s $\color{#d91a1a}-1.42\%$
test_values[vec_td1_return_estimate-False-False] 1.2822ms 1.0817ms 924.5026 Ops/s 928.7876 Ops/s $\color{#d91a1a}-0.46\%$
test_values[td_lambda_return_estimate-True-False] 87.7181ms 87.1651ms 11.4725 Ops/s 11.6767 Ops/s $\color{#d91a1a}-1.75\%$
test_values[vec_td_lambda_return_estimate-True-False] 1.2482ms 1.0752ms 930.0881 Ops/s 934.6611 Ops/s $\color{#d91a1a}-0.49\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 26.1867ms 24.4230ms 40.9450 Ops/s 42.0037 Ops/s $\color{#d91a1a}-2.52\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 1.0218ms 0.7496ms 1.3340 KOps/s 1.3450 KOps/s $\color{#d91a1a}-0.82\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.7908ms 0.6693ms 1.4942 KOps/s 1.5011 KOps/s $\color{#d91a1a}-0.46\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 1.5170ms 1.4765ms 677.2561 Ops/s 680.2476 Ops/s $\color{#d91a1a}-0.44\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 0.7291ms 0.6834ms 1.4633 KOps/s 1.4755 KOps/s $\color{#d91a1a}-0.83\%$
test_dqn_speed[False-None] 6.8569ms 1.5392ms 649.6953 Ops/s 651.7939 Ops/s $\color{#d91a1a}-0.32\%$
test_dqn_speed[False-backward] 2.4358ms 2.1314ms 469.1753 Ops/s 468.1577 Ops/s $\color{#35bf28}+0.22\%$
test_dqn_speed[True-None] 0.9635ms 0.5397ms 1.8530 KOps/s 1.7623 KOps/s $\textbf{\color{#35bf28}+5.15\%}$
test_dqn_speed[True-backward] 1.1830ms 1.0915ms 916.1532 Ops/s 889.5191 Ops/s $\color{#35bf28}+2.99\%$
test_dqn_speed[reduce-overhead-None] 0.9799ms 0.5556ms 1.7997 KOps/s 1.6661 KOps/s $\textbf{\color{#35bf28}+8.02\%}$
test_dqn_speed[reduce-overhead-backward] 1.0487ms 0.9466ms 1.0565 KOps/s 1.0037 KOps/s $\textbf{\color{#35bf28}+5.25\%}$
test_ddpg_speed[False-None] 3.2866ms 2.8551ms 350.2448 Ops/s 342.9420 Ops/s $\color{#35bf28}+2.13\%$
test_ddpg_speed[False-backward] 4.2741ms 4.1044ms 243.6400 Ops/s 238.3552 Ops/s $\color{#35bf28}+2.22\%$
test_ddpg_speed[True-None] 1.5120ms 1.0642ms 939.6704 Ops/s 884.4533 Ops/s $\textbf{\color{#35bf28}+6.24\%}$
test_ddpg_speed[True-backward] 2.2201ms 2.1345ms 468.4870 Ops/s 461.1894 Ops/s $\color{#35bf28}+1.58\%$
test_ddpg_speed[reduce-overhead-None] 1.2221ms 1.0834ms 922.9900 Ops/s 916.4068 Ops/s $\color{#35bf28}+0.72\%$
test_ddpg_speed[reduce-overhead-backward] 1.7125ms 1.6170ms 618.4324 Ops/s 609.2124 Ops/s $\color{#35bf28}+1.51\%$
test_sac_speed[False-None] 8.4740ms 8.0327ms 124.4919 Ops/s 124.4513 Ops/s $\color{#35bf28}+0.03\%$
test_sac_speed[False-backward] 11.4794ms 10.9398ms 91.4093 Ops/s 90.4087 Ops/s $\color{#35bf28}+1.11\%$
test_sac_speed[True-None] 1.9402ms 1.5222ms 656.9417 Ops/s 650.2521 Ops/s $\color{#35bf28}+1.03\%$
test_sac_speed[True-backward] 3.2575ms 3.2019ms 312.3148 Ops/s 295.3007 Ops/s $\textbf{\color{#35bf28}+5.76\%}$
test_sac_speed[reduce-overhead-None] 22.9558ms 12.7764ms 78.2691 Ops/s 79.2941 Ops/s $\color{#d91a1a}-1.29\%$
test_sac_speed[reduce-overhead-backward] 1.3930ms 1.3361ms 748.4616 Ops/s 657.0597 Ops/s $\textbf{\color{#35bf28}+13.91\%}$
test_redq_speed[False-None] 8.2254ms 7.4865ms 133.5746 Ops/s 132.1428 Ops/s $\color{#35bf28}+1.08\%$
test_redq_speed[False-backward] 12.1110ms 11.2774ms 88.6731 Ops/s 85.7011 Ops/s $\color{#35bf28}+3.47\%$
test_redq_speed[True-None] 2.1427ms 1.9779ms 505.5881 Ops/s 490.8505 Ops/s $\color{#35bf28}+3.00\%$
test_redq_speed[True-backward] 3.7529ms 3.6327ms 275.2807 Ops/s 261.7281 Ops/s $\textbf{\color{#35bf28}+5.18\%}$
test_redq_speed[reduce-overhead-None] 2.1649ms 2.0734ms 482.2958 Ops/s 485.4921 Ops/s $\color{#d91a1a}-0.66\%$
test_redq_speed[reduce-overhead-backward] 4.1424ms 3.7375ms 267.5557 Ops/s 276.8512 Ops/s $\color{#d91a1a}-3.36\%$
test_redq_deprec_speed[False-None] 9.8489ms 9.3911ms 106.4840 Ops/s 109.6089 Ops/s $\color{#d91a1a}-2.85\%$
test_redq_deprec_speed[False-backward] 13.2209ms 12.3666ms 80.8628 Ops/s 82.8562 Ops/s $\color{#d91a1a}-2.41\%$
test_redq_deprec_speed[True-None] 2.4729ms 2.3296ms 429.2627 Ops/s 429.4198 Ops/s $\color{#d91a1a}-0.04\%$
test_redq_deprec_speed[True-backward] 4.2080ms 4.0649ms 246.0076 Ops/s 253.2172 Ops/s $\color{#d91a1a}-2.85\%$
test_redq_deprec_speed[reduce-overhead-None] 2.3954ms 2.3325ms 428.7204 Ops/s 427.9175 Ops/s $\color{#35bf28}+0.19\%$
test_redq_deprec_speed[reduce-overhead-backward] 4.4975ms 3.9688ms 251.9635 Ops/s 239.1814 Ops/s $\textbf{\color{#35bf28}+5.34\%}$
test_td3_speed[False-None] 34.3608ms 8.1834ms 122.1979 Ops/s 123.1196 Ops/s $\color{#d91a1a}-0.75\%$
test_td3_speed[False-backward] 10.9389ms 10.2750ms 97.3236 Ops/s 93.7800 Ops/s $\color{#35bf28}+3.78\%$
test_td3_speed[True-None] 1.6045ms 1.5836ms 631.4683 Ops/s 635.6857 Ops/s $\color{#d91a1a}-0.66\%$
test_td3_speed[True-backward] 3.1676ms 3.1020ms 322.3757 Ops/s 295.1765 Ops/s $\textbf{\color{#35bf28}+9.21\%}$
test_td3_speed[reduce-overhead-None] 58.2632ms 25.8424ms 38.6961 Ops/s 38.9572 Ops/s $\color{#d91a1a}-0.67\%$
test_td3_speed[reduce-overhead-backward] 1.3453ms 1.2961ms 771.5378 Ops/s 688.7525 Ops/s $\textbf{\color{#35bf28}+12.02\%}$
test_cql_speed[False-None] 17.3484ms 16.7818ms 59.5882 Ops/s 59.2992 Ops/s $\color{#35bf28}+0.49\%$
test_cql_speed[False-backward] 22.4929ms 21.9949ms 45.4650 Ops/s 44.3754 Ops/s $\color{#35bf28}+2.46\%$
test_cql_speed[True-None] 2.9806ms 2.9035ms 344.4072 Ops/s 342.6605 Ops/s $\color{#35bf28}+0.51\%$
test_cql_speed[True-backward] 5.3783ms 5.0068ms 199.7301 Ops/s 196.5296 Ops/s $\color{#35bf28}+1.63\%$
test_cql_speed[reduce-overhead-None] 0.3557s 14.4530ms 69.1900 Ops/s 75.5720 Ops/s $\textbf{\color{#d91a1a}-8.44\%}$
test_cql_speed[reduce-overhead-backward] 1.7583ms 1.6928ms 590.7249 Ops/s 649.0314 Ops/s $\textbf{\color{#d91a1a}-8.98\%}$
test_a2c_speed[False-None] 3.2891ms 3.2044ms 312.0695 Ops/s 307.8417 Ops/s $\color{#35bf28}+1.37\%$
test_a2c_speed[False-backward] 6.9481ms 6.4233ms 155.6836 Ops/s 161.7433 Ops/s $\color{#d91a1a}-3.75\%$
test_a2c_speed[True-None] 1.1916ms 1.0118ms 988.3346 Ops/s 974.3100 Ops/s $\color{#35bf28}+1.44\%$
test_a2c_speed[True-backward] 2.7607ms 2.7180ms 367.9142 Ops/s 382.4393 Ops/s $\color{#d91a1a}-3.80\%$
test_a2c_speed[reduce-overhead-None] 21.5998ms 11.6216ms 86.0466 Ops/s 90.0484 Ops/s $\color{#d91a1a}-4.44\%$
test_a2c_speed[reduce-overhead-backward] 1.1668ms 1.1114ms 899.7749 Ops/s 861.1543 Ops/s $\color{#35bf28}+4.48\%$
test_ppo_speed[False-None] 3.9177ms 3.6906ms 270.9583 Ops/s 271.0536 Ops/s $\color{#d91a1a}-0.04\%$
test_ppo_speed[False-backward] 7.5016ms 7.0893ms 141.0573 Ops/s 138.8253 Ops/s $\color{#35bf28}+1.61\%$
test_ppo_speed[True-None] 0.9951ms 0.9482ms 1.0547 KOps/s 1.0407 KOps/s $\color{#35bf28}+1.34\%$
test_ppo_speed[True-backward] 2.7235ms 2.6798ms 373.1690 Ops/s 366.9783 Ops/s $\color{#35bf28}+1.69\%$
test_ppo_speed[reduce-overhead-None] 0.5746ms 0.5278ms 1.8946 KOps/s 68.9701 Ops/s $\textbf{\color{#35bf28}+2646.99\%}$
test_ppo_speed[reduce-overhead-backward] 1.1691ms 1.1120ms 899.2818 Ops/s 982.0026 Ops/s $\textbf{\color{#d91a1a}-8.42\%}$
test_reinforce_speed[False-None] 2.4389ms 2.2579ms 442.8852 Ops/s 434.1367 Ops/s $\color{#35bf28}+2.02\%$
test_reinforce_speed[False-backward] 3.8288ms 3.3951ms 294.5439 Ops/s 293.9229 Ops/s $\color{#35bf28}+0.21\%$
test_reinforce_speed[True-None] 0.8662ms 0.8222ms 1.2162 KOps/s 1.1569 KOps/s $\textbf{\color{#35bf28}+5.12\%}$
test_reinforce_speed[True-backward] 2.5577ms 2.5134ms 397.8683 Ops/s 399.9110 Ops/s $\color{#d91a1a}-0.51\%$
test_reinforce_speed[reduce-overhead-None] 0.2909s 12.1019ms 82.6314 Ops/s 87.8813 Ops/s $\textbf{\color{#d91a1a}-5.97\%}$
test_reinforce_speed[reduce-overhead-backward] 1.2069ms 1.1626ms 860.1473 Ops/s 967.7263 Ops/s $\textbf{\color{#d91a1a}-11.12\%}$
test_iql_speed[False-None] 9.9283ms 9.3367ms 107.1041 Ops/s 106.6767 Ops/s $\color{#35bf28}+0.40\%$
test_iql_speed[False-backward] 13.7263ms 13.3251ms 75.0463 Ops/s 76.2022 Ops/s $\color{#d91a1a}-1.52\%$
test_iql_speed[True-None] 2.1078ms 1.7519ms 570.8004 Ops/s 570.2023 Ops/s $\color{#35bf28}+0.10\%$
test_iql_speed[True-backward] 4.6076ms 4.3722ms 228.7183 Ops/s 234.0122 Ops/s $\color{#d91a1a}-2.26\%$
test_iql_speed[reduce-overhead-None] 20.1468ms 11.4850ms 87.0701 Ops/s 86.8130 Ops/s $\color{#35bf28}+0.30\%$
test_iql_speed[reduce-overhead-backward] 1.6945ms 1.6013ms 624.4858 Ops/s 614.3693 Ops/s $\color{#35bf28}+1.65\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 8.0018ms 6.4002ms 156.2442 Ops/s 154.5468 Ops/s $\color{#35bf28}+1.10\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.5987ms 0.3460ms 2.8904 KOps/s 2.7989 KOps/s $\color{#35bf28}+3.27\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.5748ms 0.3259ms 3.0680 KOps/s 2.9652 KOps/s $\color{#35bf28}+3.47\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 6.3846ms 6.1745ms 161.9572 Ops/s 161.6285 Ops/s $\color{#35bf28}+0.20\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.9469ms 0.3138ms 3.1865 KOps/s 3.0558 KOps/s $\color{#35bf28}+4.28\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.5458ms 0.2891ms 3.4588 KOps/s 4.1386 KOps/s $\textbf{\color{#d91a1a}-16.43\%}$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] 1.6056ms 1.3648ms 732.6984 Ops/s 794.1349 Ops/s $\textbf{\color{#d91a1a}-7.74\%}$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] 1.5165ms 1.2670ms 789.2777 Ops/s 851.0564 Ops/s $\textbf{\color{#d91a1a}-7.26\%}$
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 6.5062ms 6.3532ms 157.4000 Ops/s 156.4437 Ops/s $\color{#35bf28}+0.61\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 2.0695ms 0.5009ms 1.9965 KOps/s 2.3940 KOps/s $\textbf{\color{#d91a1a}-16.60\%}$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.6876ms 0.4782ms 2.0911 KOps/s 2.3411 KOps/s $\textbf{\color{#d91a1a}-10.68\%}$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 6.3868ms 6.1673ms 162.1457 Ops/s 159.8224 Ops/s $\color{#35bf28}+1.45\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.9050ms 0.2738ms 3.6526 KOps/s 3.4274 KOps/s $\textbf{\color{#35bf28}+6.57\%}$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.4811ms 0.2537ms 3.9420 KOps/s 3.1579 KOps/s $\textbf{\color{#35bf28}+24.83\%}$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 6.2850ms 6.0736ms 164.6459 Ops/s 162.4610 Ops/s $\color{#35bf28}+1.34\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.6303ms 0.2961ms 3.3778 KOps/s 3.0012 KOps/s $\textbf{\color{#35bf28}+12.55\%}$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.6184ms 0.2821ms 3.5450 KOps/s 3.2815 KOps/s $\textbf{\color{#35bf28}+8.03\%}$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 6.4353ms 6.2405ms 160.2442 Ops/s 156.5279 Ops/s $\color{#35bf28}+2.37\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1.0565ms 0.4991ms 2.0036 KOps/s 2.2206 KOps/s $\textbf{\color{#d91a1a}-9.77\%}$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.9401ms 0.4795ms 2.0857 KOps/s 2.1624 KOps/s $\color{#d91a1a}-3.55\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 7.0958ms 5.4961ms 181.9458 Ops/s 184.3960 Ops/s $\color{#d91a1a}-1.33\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 7.9453ms 2.0559ms 486.4165 Ops/s 430.4125 Ops/s $\textbf{\color{#35bf28}+13.01\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 7.9404ms 1.2177ms 821.1977 Ops/s 879.1141 Ops/s $\textbf{\color{#d91a1a}-6.59\%}$
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 7.3434ms 5.4313ms 184.1192 Ops/s 183.8354 Ops/s $\color{#35bf28}+0.15\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 8.3253ms 2.0149ms 496.3137 Ops/s 473.7291 Ops/s $\color{#35bf28}+4.77\%$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 7.0066ms 1.2308ms 812.4764 Ops/s 785.3183 Ops/s $\color{#35bf28}+3.46\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.4926s 15.4457ms 64.7431 Ops/s 33.2816 Ops/s $\textbf{\color{#35bf28}+94.53\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 10.8323ms 2.2531ms 443.8255 Ops/s 455.5096 Ops/s $\color{#d91a1a}-2.57\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 6.3288ms 1.3502ms 740.6333 Ops/s 715.9294 Ops/s $\color{#35bf28}+3.45\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] 15.7634ms 15.4062ms 64.9088 Ops/s 64.0810 Ops/s $\color{#35bf28}+1.29\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] 19.8999ms 17.5595ms 56.9492 Ops/s 56.8124 Ops/s $\color{#35bf28}+0.24\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] 20.0776ms 19.5433ms 51.1684 Ops/s 49.8886 Ops/s $\color{#35bf28}+2.57\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] 19.6548ms 17.5397ms 57.0135 Ops/s 57.2406 Ops/s $\color{#d91a1a}-0.40\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] 19.9450ms 19.5031ms 51.2738 Ops/s 50.4255 Ops/s $\color{#35bf28}+1.68\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] 20.8772ms 18.9445ms 52.7859 Ops/s 52.8257 Ops/s $\color{#d91a1a}-0.08\%$

[ghstack-poisoned]
[ghstack-poisoned]
@vmoens
Copy link
Contributor Author

vmoens commented Jan 10, 2025

Supporting all these usecases might become a big headacke which is why I preferred to stick with one reward and done key per loss class.

In its current version, this PR assumes that you can have a multihead action but reward / done are going to be tensors.

What we do is that everytime we need to re-compute anything from the dist, we look at it and if it's a composite dist AND if you didn't explicitly ask to aggregate the log-probs, we get a tensordict of log-probs.

From there, for PPOLoss and KL version, the change is quite trivial.

For ClipPPOLoss, there's a bit of change in the logic: before, we were summing all the log-probs (or weights), then clamping, then multiplying by the advantage.

Now, we first clamp each weight leaf, then sum and multiply. Hopefully that should be more mathematically accurate since but I'm happy to revert this if people yell at me!

There is still something I would like to implement but it could be a bit bc-breaking so I'd rather get people's opinion on it: currently, we kind of assume that users have set the return_log_prob=True in the distribution, but we could spare that. All we need is the params of the dist and the sample to compute the log-prob, and the params are presumably part of your tensordict. So we could rebuild the original dist at low cost during the loss computation.

This is what we have now:

with self.actor_network_params.to_module(
self.actor_network
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict)
try:
prev_log_prob = _maybe_get_or_select(
tensordict, self.tensor_keys.sample_log_prob
)
except KeyError as err:
raise _make_lp_get_error(self.tensor_keys, tensordict, err)

The way I'm thinking about this is to append a PR to this stack where:

  • if the parameters (say, loc and scale) are present, we always recompute the dist. If the log-prob is there too, we tell the user that this is not necessary in a warning;
  • If the parameters are not there, we fall back on a stored log-prob if there is.

cc @louisfaury

@matteobettini
Copy link
Contributor

Why can't you (1) read the logprob if there is and if not (2) recompute it if there are the original dist prams and if not (3) throw error?

@vmoens
Copy link
Contributor Author

vmoens commented Jan 10, 2025

We could, it's mainly a matter of what is the "default" to me.
IMO the default should be not asking users to compute anything we can do ourselves. But I agree that both options are fine (and in eager mode the new behaviour may be slightly slower - although considering the big picture it'll probably be roughly the same if you include the time it takes to compute the lp during inference).

@vmoens
Copy link
Contributor Author

vmoens commented Jan 10, 2025

nm, this is a bit more ambitious than I thought since VTrace requires the log-prob to be present and we want the advantage to be callable outside of the loss

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@vmoens vmoens merged commit b05d735 into gh/vmoens/58/base Jan 16, 2025
57 of 65 checks passed
vmoens added a commit that referenced this pull request Jan 16, 2025
ghstack-source-id: c41718e697f9b6edda17d4ddb5bd6d41402b7c30
Pull Request resolved: #2665
@vmoens vmoens deleted the gh/vmoens/58/head branch January 16, 2025 11:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants