Releases: pytorch/rl
TorchRL 0.10.1: Fixes and named dimensions in composite specs
Release Notes - v0.10.1
This patch release includes bug fixes, type annotation improvements, and CI enhancements cherry-picked from main.
Bug Fixes
- #3168 - @vmoens - [BugFix] AttributeError in accept_remote_rref_udf_invocation
- Fixed AttributeError in RPC utilities when decorating classes with remote RRef invocation by handling None values in getattr calls
Features
- #3174 - @vmoens - [Feature] Named dims in Composite
  - Added support for named dimensions in Composite specs, enabling better integration with PyTorch's named tensors
- #3214 - @louisfaury - [Feature] Composite specs can create named tensors with 'zero' and 'rand'
  - Extended Composite specs to properly propagate names when creating tensors using the zero() and rand() methods
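Name propagation builds on PyTorch's named-tensor support. As a rough sketch of the resulting behavior (illustrative shapes and dimension names only, not TorchRL's exact spec API), a spec carrying named dims would hand those names to the tensors it creates:

```python
import torch

# Named dimensions attach labels to tensor axes; a Composite spec with
# named dims can propagate these labels into tensors created by its
# zero()/rand() methods.
zeros = torch.zeros(4, 3, names=("batch", "feature"))
rand = torch.rand(4, 3, names=("batch", "feature"))

print(zeros.names)  # ('batch', 'feature')
```

Named dims then survive name-aware operations such as `tensor.align_to(...)`, which is what makes the integration with named tensors useful downstream.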
Type Annotations & Documentation
- @vmoens - [Typing] Edit wrongfully set str type annotations
  - Fixed incorrect string type annotations across 19 files
- #3175 - @vmoens - [Versioning] Fix doc versioning
  - Fixed documentation versioning issues
CI/Build Improvements
- #3200 - @vmoens - [CI] Use pip install
  - Updated CI workflows to use pip install across 41 files
- @vmoens - [CI] Fix missing librhash0 in doc CI
  - Added missing librhash0 dependency in documentation CI
- @vmoens - [CI] Fix benchmarks for LLMs
  - Fixed LLM benchmark CI configurations
- #3222 - @vmoens - [CI] Upgrade doc python version
  - Upgraded Python version in documentation build workflows and added vLLM plugin entry point for FP32 overrides
TorchRL 0.10.0: async LLM inference
TorchRL 0.10.0 Release Notes
What's New in 0.10.0
TorchRL 0.10.0 introduces significant advancements in Large Language Model (LLM) support, new algorithms, enhanced environment integrations, and numerous performance improvements and bug fixes.
Major Features
LLM Support and RLHF
- vLLM Integration Revamp: Complete overhaul of vLLM support with improved batching and performance (#3158) @vmoens
- GRPO (Group Relative Policy Optimization): New algorithm implementation with both sync and async variants (#2970, #2997, #3006) @vmoens
- Expert Iteration and SFT: Implementation of expert iteration algorithms and supervised fine-tuning (#3017) @vmoens
- PPOTrainer: New high-level trainer class for PPO training (#3117) @vmoens
- LLM Tooling: Comprehensive tooling support for LLM environments and transformations (#2966) @vmoens
- Remote LLM Wrappers: Support for remote LLM inference with improved batching (#3116) @vmoens
- Common LLM Generation Interface: Unified kwargs for generation across vLLM and Transformers (#3107) @vmoens
- LLM Transforms:
  - Content Management: ContentBase system for structured content handling (#2985) @vmoens
  - History Tracking: New history system for conversation management (#2965) @vmoens
New Algorithms and Training
- Async SAC: Asynchronous implementation of Soft Actor-Critic (#2946) @vmoens
- Discrete Offline CQL: SOTA implementation for discrete action spaces (#3098) @Ibinarriaga
- Multi-node Ray Support: Enhanced distributed training for GRPO (#3040) @albertbou92
Environment Support
- NPU Support: Added NPU device support for SyncDataCollector (#3155) @lowdy1
- IsaacLab Wrapper: Integration with IsaacLab simulation framework (#2937) @vmoens
- Complete PettingZoo State Support: Enhanced multi-agent environment support (#2953) @JGuzzi
- Minari Integration: Support for loading datasets from local Minari cache (#3068) @Ibinarriaga
Storage and Replay Buffers
- Compressed Storage GPU: GPU acceleration for compressed replay buffers (#3062) @aorenstein68
- Packing: New data packing functionality for efficient storage (#3060) @vmoens
- Ray Replay Buffer: Enhanced distributed replay buffer support (#2949) @vmoens
🔧 Improvements and Enhancements
Performance Optimizations
- Bounded Specs Memory: Single copy optimization for bounded specifications (#2977) @vmoens
- Log-prob Computation: Avoid unnecessary log-prob calculations when retrieving distributions (#3081) @vmoens
- LLM Wrapper Queuing: Performance fixes in LLM wrapper queuing (#3125) @vmoens
- vmap Deactivation: Selective vmap deactivation in objectives for better performance (#2957) @vmoens
API Improvements
- Public SAC Methods: Exposed public methods for SAC algorithm (#3085) @vmoens
- Composite Entropy: Fixed entropy computation for nested keys (#3101) @juandelos
- Multi-head Entropy: Per-head entropy coefficients for PPO (#2972) @Felixs
- ClippedPPOLoss: Support for composite value networks (#3031) @louisfaury
- LineariseRewards: Support for negative weights (#3064) @YoannPoupart
- GAE Typing: Improved typing with optional value networks (#3029) @louisfaury
- Explained Variance: Optional explained variance logging (#3010) @OswaldZink
- Frame Control: Worker-level control over frames_per_batch (#3020) @alexghh
Developer Experience
- Colored Logger: Enhanced logging with colored output (#2967) @vmoens
- Better Error Handling: Improved error catching in env.rollout and rb.add (#3102) @vmoens
- Warning Management: Better warning control for various components (#3099, #3115) @vmoens
- Faster Tests: Optimized test suite performance (#3162) @vmoens
Bug Fixes
Core Functionality
- PRB Serialization: Fixed Prioritized Replay Buffer serialization and loading (#3151, #2963) @vmoens
- Binary Operations: Fixed Binary tensor reshaping and clone operations (#3084, #3077) @LucaCarminati @vmoens
- Categorical Spec: Fixed dtype sampling and masking issues (#2980, #2981) @louisfaury
- ActionMask: Compatibility with composite action specifications (#3022) @louisfaury
- GAE with LSTM: Fixed shifted value computation with LSTM networks (#2941) @vmoens
- Cross-entropy: Fixed log-prob computation for batched input (#3080) @vmoens
Environment and Wrapper Fixes
- TransformedEnv: Fixed in-place modification of specs (#3076) @vmoens
- Parallel Environments: Fixed partial and nested done states (#2959) @vmoens
- Gym Actions: Fixed single action passing when action key is not "action" (#2942) @vmoens
- Brax Memory: Fixed memory leak in Brax environments (#3052) @vmoens
- Atari Patching: Fixed patching for NonTensorData observations (#3091) @marcosGR
Collector and Replay Buffer Fixes
- LLMCollector: Fixed trajectory collection when multiple trajectories complete (#3018) @albertbou92
- Postprocessing: Consistent postprocessing when using replay buffers in collectors (#3144) @vmoens
- Weight Updates: Fixed original weights retrieval in collectors (#2951) @vmoens
- Transform Handling: Fixed transform application and metadata preservation (#3047, #3050) @vmoens
Compatibility and Infrastructure
- PyTorch 2.1.1: Fixed compatibility issues (#3157) @vmoens
- NPU Attribute: Fixed missing NPU attribute (#3159) @vmoens
- CUDA Graph: Fixed update_policy_weights_ with CUDA graphs (#3003) @vmoens
- Stream Capturing: Robust CUDA stream capturing calls (#2950) @vmoens
Documentation and Tutorials
- DQN with RNN Tutorial: Upgraded tutorial with latest best practices (#3152) @vmoens
- LLM API Documentation: Comprehensive documentation for LLM environments and transforms (#2991) @vmoens
- Multi-head Entropy: Better documentation for multi-head entropy usage (#3109) @vmoens
- LSTM Module: Fixed import examples in documentation (#3138) @arvindcr4
- A2C Documentation: Updated AcceptedKeys documentation (#2987) @simeet-n
- History API: Added missing docstrings for History functionality (#3083) @vmoens
- Multi-agent PPO: Fixed tutorial issues (#2940) @matteobettini
- WeightUpdater: Updated documentation after renaming (#3007) @albertbou92
Infrastructure and CI
- Pre-commit Updates: Updated formatting and linting tools (#3108) @vmoens
- Benchmark CI: Fixed benchmark runs and added missing dependencies (#3092, #3163) @vmoens
- Windows CI: Fixed Windows continuous integration (#3028) @vmoens
- Old Dependencies: Fixed CI for older dependency versions (#3165) @vmoens
- C++ Linting: Fixed C++ code linting issues (#3129) @vmoens
- Build System: Improved pyproject.toml usage and versioning (#3089, #3166) @vmoens
🏆 Contributors
Special thanks to all contributors who made this release possible:
- @albertbou92 (Albert Bou) - GRPO multi-node support and LLM improvements
- @Ibinarriaga - CQL offline algorithm and Minari integration
- @aorenstein68 (Adrian Orenstein) - Compressed storage GPU support
- @louisfaury (Louis Faury) - Categorical spec and PPO improvements
- @LucaCarminati (Luca Carminati) - Binary tensor fixes
- @JGuzzi (Jérôme Guzzi) - PettingZoo state support
- @lowdy1 - NPU device support
- @Felixs (Felix Sittenauer) - Multi-head entropy coefficients
- @YoannPoupart (Yoann Poupart) - LineariseRewards improvements
- @OswaldZink (Oswald Zink) - Explained variance logging
- @alexghh (Alexandre Ghelfi) - Frame control improvements
- @marcosGR (Marcos Galletero Romero) - Atari patching fixes
- @matteobettini (Matteo Bettini) - Tutorial fixes
- @simeet-n (Simeet Nayan) - Documentation improvements
- @arvindcr4 - Documentation fixes
- @felixy12 (Felix Yu) - State dict reference fixes
- @SendhilPanchadsaram (Sendhil Panchadsaram) - Documentation typo fixes
- @abhishekunique (Abhishek) - WandB logger and value estimation improvements
- @骑马小猫 - DQN module typo fix
- @ZainRizvi (Zain Rizvi) - CI improvements and meta-pytorch migration
- @mikayla-gawarecki (Mikayla Gawarecki) - Usage tracking and ConditionalPolicySwitch
🔗 Compatibility
- PyTorch: Compatible with PyTorch 2.1.1+ -- recommended >=2.8.0,<2.9.0 for full compatibility
- TensorDict: Updated to work with TensorDict 0.10+
- Python: Supports Python 3.9+
📦 Installation
pip install torchrl==0.10.0
For the latest features:
pip install git+https://github.com/pytorch/rl.git@release/0.10.0
v0.9.2: Bug fixes and perf improvements
TorchRL 0.9.2 Release Notes
This release focuses on bug fixes, performance improvements, and code quality enhancements.
🚀 New Features
- LineariseRewards: Now supports negative weights for more flexible reward shaping (#3064)
🐛 Bug Fixes
- Fixed policy reference handling in state dictionaries (#3043)
- Improved unbatched data handling in LLM wrappers (#3070)
- Fixed cross-entropy log-probability computation for batched inputs (#3080)
- Fixed Binary clone() operations (#3077)
- Fixed in-place spec modifications in TransformedEnv (#3076)
⚡ Performance Improvements
- Optimized distribution sampling by avoiding unnecessary log-probability computations (#3081)
🔧 Code Quality
- Standardized coefficient naming in A2C and PPO algorithms (#3079)
📦 Installation
pip install torchrl==0.9.2
Thanks to all contributors: @felixy12, @Xmaster6y, @louisfaury and @LCarmi
v0.9.1: fix for history-based vLLM and Transformers wrappers
Fixes a critical issue with vLLMWrapper and TransformersWrapper, where an already-stacked batch of History objects was passed to stack a second time, resulting in a bug.
TorchRL 0.9.0 Release Notes
We are excited to announce the release of TorchRL 0.9.0! This release introduces a comprehensive LLM API for language model fine-tuning, extensive torch.compile compatibility across all algorithms, and numerous performance improvements.
🚀 Major Features
🤖 LLM API - Complete Framework for Language Model Fine-tuning
TorchRL now includes a comprehensive LLM API for post-training and fine-tuning of language models! This new framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:
- Unified LLM Wrappers (torchrl.modules.llm): Seamless integration with Hugging Face models and vLLM inference engines
- Conversation Management (torchrl.data.llm.history): Advanced History class for multi-turn dialogue with automatic chat template detection
- Tool Integration (torchrl.envs.llm.transforms): Built-in support for Python code execution, function calling, and custom tool transforms
- Specialized Objectives (torchrl.objectives.llm): GRPO (Group Relative Policy Optimization) and SFT loss functions optimized for language models
- High-Performance Collectors (torchrl.collectors.llm): Async data collection with distributed training support
- Flexible Environments (torchrl.envs.llm): Transform-based architecture for reward computation, data loading, and conversation augmentation
The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the complete documentation and GRPO implementation example to get started!
Unified LLM Wrappers
- TransformersWrapper: Seamless integration with Hugging Face models
- vLLMWrapper: High-performance inference with vLLM engines
- Consistent API: Both wrappers provide unified input/output interfaces using TensorClass objects
- Multiple input modes: Support for history, text, and tokenized inputs
- Configurable outputs: Text, tokens, masks, and log probabilities
Advanced Conversation Management
- History class: Advanced bidirectional conversation management with automatic chat template detection
- Multi-model support: Automatic template detection for various model families (Qwen, DialoGPT, Falcon, DeepSeek, etc.)
- Assistant token masking: Identify which tokens were generated by the assistant for RL applications
- Tool calling support: Handle function calls and tool responses in conversations
- Batch operations: Efficient tensor operations for processing multiple conversations
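As a minimal, library-agnostic illustration of the bookkeeping such a History class performs (this is not TorchRL's History API; all names below are invented for the sketch), a multi-turn conversation tracker might look like:

```python
from dataclasses import dataclass, field


@dataclass
class Turn:
    role: str      # "user", "assistant", "system", ...
    content: str


@dataclass
class Conversation:
    turns: list = field(default_factory=list)

    def append(self, role: str, content: str) -> "Conversation":
        self.turns.append(Turn(role, content))
        return self

    def assistant_mask(self):
        # Flags marking which turns the assistant produced, analogous
        # to assistant token masking used by RL objectives such as SFT/GRPO.
        return [t.role == "assistant" for t in self.turns]


conv = Conversation().append("user", "Hi").append("assistant", "Hello!")
print(conv.assistant_mask())  # [False, True]
```

TorchRL's actual History class operates on batched tensor representations and applies model-specific chat templates on top of this kind of structure.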
🛠️ Tool Integration
- PythonInterpreter transform: Built-in Python code execution capabilities
- MCPToolTransform: General tool calling support
- Extensible architecture: Easy to add custom tool transforms
- Safe execution: Controlled environment for tool execution
🎯 Specialized Objectives
- GRPOLoss: Group Relative Policy Optimization loss function optimized for language models
- SFTLoss: Supervised fine-tuning loss with assistant token masking support
- MCAdvantage: Monte-Carlo advantage estimation for LLM training
- KL divergence rewards: Built-in KL penalty computation
⚡ High-Performance Collectors
- LLMCollector: Async data collection with distributed training support
- RayLLMCollector: Multi-node distributed collection using Ray
- Weight synchronization: Automatic model weight updates across distributed setups
- Trajectory management: Efficient handling of variable-length conversations
🔄 Flexible Environments
- ChatEnv: Transform-based architecture for conversation management
- Transform-based rewards: Modular reward computation and data loading
- Dataset integration: Built-in support for loading prompts from datasets
- Thinking prompts: Chain-of-thought reasoning support
📚 Complete Implementation Example
A full GRPO implementation is provided in sota-implementations/grpo/ with:
- Multi-GPU support with efficient device management
- Mixed precision training
- Gradient accumulation
- Automatic checkpointing
- Comprehensive logging with Weights & Biases
- Hydra configuration system
- Asynchronous training support with Ray
🆕 New Features
LLM API Components
- LLMMaskedCategorical (#3041) - Categorical distribution with masking for LLM token selection
- AddThinkingPrompt transform (#3027) - Add chain-of-thought reasoning prompts
- MCPToolTransform (#2993) - Model Context Protocol tool integration
- PythonInterpreter transform (#2988) - Python code execution in LLM environments
- ContentBase (#2985) - Base class for structured content in LLM workflows
- LLM Tooling (#2966) - Comprehensive tool integration framework
- History API (#2965) - Advanced conversation management system
- LLM collector (#2879) - Specialized data collection for language models
- vLLM wrapper (#2830) - High-performance vLLM integration
- Transformers policy (#2825) - Hugging Face transformers integration
Environment Enhancements
- IsaacLab wrapper (#2937) - NVIDIA Isaac Lab environment support
- Complete PettingZooWrapper state support (#2953) - Full state management for multi-agent environments
- ConditionalPolicySwitch transform (#2711) - Dynamic policy switching based on conditions
- Async environments (#2864) - Asynchronous environment execution
- VecNormV2 (#2867) - Improved vector normalization with batched environment support
Algorithm Improvements
- Async GRPO (#2997) - Asynchronous Group Relative Policy Optimization
- Expert Iteration and SFT (#3017) - Expert iteration and supervised fine-tuning algorithms
- Async SAC (#2946) - Asynchronous Soft Actor-Critic implementation
- Multi-node Ray support for GRPO (#3040) - Distributed GRPO training
Data Management
- RayReplayBuffer (#2835) - Distributed replay buffer using Ray
- RayReplayBuffer usage examples (#2949) - Comprehensive usage examples
- Policy factory for collectors (#2841) - Flexible policy creation in collectors
- Local and Remote WeightUpdaters (#2848) - Distributed weight synchronization
Performance Optimizations
- Deactivate vmap in objectives (#2957) - Improved performance by disabling vectorized operations
- Hold a single copy of low/high in bounded specs (#2977) - Memory optimization for bounded specifications
- Use TensorDict._new_unsafe in step (#2905) - Performance improvement in environment steps
- Memoize calls to encode and related methods (#2907) - Caching for improved performance
Utility Features
- Compose.pop (#3026) - Remove transforms from composition
- Add optional Explained Variance logging (#3010) - Enhanced logging capabilities
- Enabling worker level control on frames_per_batch (#3020) - Granular control over data collection
- collector.start() (#2935) - Explicit collector lifecycle management
- Timer transform (#2806) - Timing capabilities for environments
- MultiAction transform (#2779) - Multi-action environment support
- Transform for partial steps (#2777) - Partial step execution support
🔧 Performance Improvements
- VecNormV2: Improved vector normalization with better bias correction timing (#2900, #2901)
- MaskedCategorical cross_entropy: Faster loss computation (#2882)
- Avoid padding in transformer wrapper: Memory and performance optimization (#2881)
- Set padded token log-prob to 0.0: Improved numerical stability (#2857)
- Better device checks: Enhanced device management (#2909)
- Local dtype maps: Optimized dtype handling (#2936)
🐛 Bug Fixes
LLM API Fixes
- Variable length vllm wrapper answer stacking (#3049) - Fixed stacking issues with variable-length responses
- LLMCollector trajectory collection methods (#3018) - Fixed trajectory collection when multiple trajectories complete simultaneously
- Fix IFEval GRPO runs (#3012) - Resolved issues with IFEval dataset runs
- Fix cuda cache empty in GRPO scripts (#3016) - Memory management improvements
- Right log-prob size in transformer wrapper (#2856) - Fixed log probability tensor sizing
- Fix gc import (#2862) - Import error resolution
Environment Fixes
- Brax memory leak fix (#3052) - Resolved memory leaks in Brax environments
- Fix behavior of partial, nested dones in PEnv and TEnv (#2959) - Improved done state handling
- Fix shifted value computation with an LSTM (#2941) - LSTM value computation fixes
- Fix single action pass to gym when action key is not "action" (#2942) - Action key handling improvements
- Fix PEnv device copies (#2840) - Device management in parallel environments
Data Management Fixes
-...
v0.8.1: Async collectors patch
Async Collector execution
The major upgrade in this release is a patch to collector.start() that allows collectors (single- or multi-process) to run asynchronously. #2935
An example is provided in the async SAC example. #2946
Single-agent reset
Fixes #2958, where partial resets were not handled correctly when a BatchedEnv is transformed, as the "done" checks were inconsistent. We now enforce that root "_reset" entries always precede their respective leaves.
Fix shifted values in GAE using LSTMs
Using an LSTM within GAE is facilitated by ensuring that shifted=True and shifted=False work properly (with appropriate warnings/errors if other hyperparameters need to be set). #2941
Full Changelog: v0.8.0...v0.8.1
v0.8.0: Async envs and better weight update API
TorchRL v0.8.0: Async envs and better weight update API
- Async environments: #2864 introduces asynchronous environments, which can be built using different backends (currently "threading" or "multiprocessing"). Instantiating an async env is roughly the same as a parallel one:

      from torchrl.envs import AsyncEnvPool
      env = AsyncEnvPool(
          [partial(GymEnv, "Pendulum-v1"), partial(GymEnv, "Pendulum-v1")],
          backend="threading",
      )

  These environments support the regular environment methods (reset, step or rollout), but their main advantage lies in their new async methods:

      s0 = env.rand_action(env.reset())
      env.async_step_send(s0)
      # receive
      result = env.async_step_recv()

  In this example, result will contain the results of the call to step for one or two environments. The environment indices can be found in the result['env_index'] entry (the name of that key is stored in env._env_idx_key).
- Support for environments with tensorclass attributes (#2788)
- Distributed RayReplayBuffer (#2835)
- Gymnasium 1.1 compatibility (#2898): we managed to make TorchRL compatible with Gymnasium 1.1, as this version lets users choose how to handle partial resets, which facilitates integration in the library.
- VecNormV2, a new version of VecNorm which is more numerically stable and easier to handle. This can be created directly through the usual VecNorm by passing the new_api keyword argument.
- Policy factory for collectors: you can now pass a factory for your policy instead of passing the real object. Given that the collector will update the weights of the policy when asked to, this will in most cases not cause any synchronization problem with the copy that is used by the training pipeline.
- An update API for policy weights in collectors: we have isolated the weight update API in a torchrl.collectors.WeightUpdaterBase abstract class. This should be the entry point for any user wanting to implement their own weight update strategy, alleviating the need to subclass or patch the collector or the policy directly.
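The async send/recv pattern described for AsyncEnvPool above can be sketched with standard-library primitives. This is an illustrative toy showing the dispatch/collect flow, not TorchRL's implementation:

```python
import queue
import threading


def worker(inbox: queue.Queue, outbox: queue.Queue, env_index: int) -> None:
    """Toy env worker: receives actions, returns fake step results."""
    while True:
        action = inbox.get()
        if action is None:  # shutdown signal
            break
        # A real environment would run env.step(action) here.
        outbox.put({"env_index": env_index, "obs": action + 1})


inboxes = [queue.Queue() for _ in range(2)]
outbox = queue.Queue()
threads = [
    threading.Thread(target=worker, args=(ib, outbox, i), daemon=True)
    for i, ib in enumerate(inboxes)
]
for t in threads:
    t.start()

# "async_step_send": dispatch actions without blocking on results
for ib, action in zip(inboxes, (10, 20)):
    ib.put(action)

# "async_step_recv": collect results as they become available;
# env_index tells you which environment each result came from
results = [outbox.get() for _ in range(2)]
print(sorted(r["obs"] for r in results))  # [11, 21]

for ib in inboxes:
    ib.put(None)  # shut workers down
```

The key point is that sending and receiving are decoupled, so the caller can overlap policy computation with environment stepping.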
Packaging
We relaxed TorchRL's dependency constraints to make it compatible with any PyTorch version. The current status is:
- tensordict dependency will from now on be enforced (>=0.8.1,<0.9.0 for this release)
- For PyTorch prior to 2.7.0, backward compatibility is guaranteed to some extent (most classes should work, unless new features are used), but C++ binaries (for prioritized replay buffers) will not work.
- For PyTorch >= 2.7.0, C++ binaries should work across versions. In other words, torchrl binaries for 0.8.0 will work with PyTorch 2.7.0, 2.8.0 etc., and the same goes for the future TorchRL 0.9.0... A big thanks to @janeyx99 for enabling this!
New features
[Feature] Add EnvBase.all_actions (#2780) (67c3e9a) by @kurtamohler ghstack-source-id: 7abf9d469f740be5f14daffa2330811f7572dad9
[Feature] Add MCTSForest/Tree.to_string (#2794) (f862669) by @kurtamohler ghstack-source-id: 2127bf24d66e44fb310d12ff5f72e92aa0371cd7
[Feature] Add include_hash_inv arg to ChessEnv (#2766) (3be85c6) by @kurtamohler ghstack-source-id: f6920d781835902a6db02f74c5e5a3041243c5e3
[Feature] Add option for auto-resetting envs in GAE (#2851) (f5f3ae4) by @lin-erica Co-authored-by: Erica Lin [email protected]
[Feature] Async environments (#2864) (4f00025) by @vmoens ghstack-source-id: 0a70ce0129d2ee6f85bb22adda3c332ff65e7501
[Feature] Capture wrong spec transforms (1/N) (#2805) (d3dca73) by @vmoens ghstack-source-id: f2d938b3dfe88af66622099f60cd7e3026289a02
[Feature] Collectors for async envs (#2893) (4ba5066) by @vmoens ghstack-source-id: 764c21d0f2c3b217440e1a6f12ee797b17820c1d
[Feature] DensifyReward postproc (#2823) (53065cf) by @vmoens ghstack-source-id: ef6a0f52601642c8944f63f9e3ac9e963425734e
[Feature] Dynamic specs for make_composite_from_td (#2829) (413571b) by @vmoens ghstack-source-id: 79e31e737c9f67ff20ce9fe32081e5b0a83de947
[Feature] Enable Hash.inv (#2757) (32c4623) by @kurtamohler ghstack-source-id: 956708121067855e519382a37764f06f53b16aa7
[Feature] Env with tensorclass attributes (#2788) (ab76027) by @vmoens ghstack-source-id: dc00ea3d23e015756974cd5c2ce638b55e5f6f92
[Feature] Gymnasium 1.1 compatibility (#2898) (78cd755) by @vmoens ghstack-source-id: e0891867f4318380f01c15449f9f26070b78536d
[Feature] History API (#2890) (fd10fe2) by @vmoens ghstack-source-id: 5b9723f6e1c327625e1a9be6f6eac68b91ed8492
[Feature] History.default_spec (#2894) (8ce11a8) by @vmoens ghstack-source-id: 40b8a492765a85adaccb591f1bc173754bacc313
[Feature] Local and Remote WeightUpdaters (#2848) (27d3680) by @vmoens ghstack-source-id: 2962530f87b596d038e3a13a934ea09064af2964
[Feature] Make PPO ready for text-based data (#2857) (595ddb4) by @vmoens ghstack-source-id: eeda5e2355e573e74cf7c080994cd47520ecd45b
[Feature] MultiAction transform (#2779) (621776a) by @vmoens ghstack-source-id: 0a6f7f916ee6f9c6d450c511385bdfdb1d911da0
[Feature] NonTensor batched arg (#2816) (b97bdb5) by @vmoens ghstack-source-id: c6de1bd1f1475b8d02df2ff3eb7438a50f2ae450
[Feature] Pass lists of policy_factory (#2888) (82f8ec2) by @vmoens ghstack-source-id: e42b100096c6e38365f8a80681473746f51d8a77
[Feature] RayReplayBuffer (#2835) (50af984) by @vmoens ghstack-source-id: 32eff06494037a1a30e532539794035c035f1e81
[Feature] Set padded token log-prob to 0.0 (#2856) (b9ddfa9) by @vmoens ghstack-source-id: 2b2993e0b15afae17326e6583390d57068712d4f
[Feature] Support lazy tensordict inputs in ppo loss (#2883) (c9caf3d) by @vmoens ghstack-source-id: 89098ba3ca61b1524aeddc68f54c377f29c8dc8b
[Feature] TensorDictPrimer with single default_value callable (#2732) (59e8545) by @vmoens ghstack-source-id: a9a677f24fc1e6a47312d0a96ab60daae543ff78
[Feature] Timer transform (#2806) (104b880) by @vmoens ghstack-source-id: e42f2aece15f90afc457e1fb3e41a1f7be1a6a85
[Feature] Transform for partial steps (#2777) (7c034e3) by @vmoens ghstack-source-id: 587f91e33dfe1d59b73c4b2f2f1c21760ee79d2e
[Feature] VecNormV2 (#2867) (40fcdb6) by @vmoens ghstack-source-id: 639d07ff54be200d54621c2c4619ebd0d3d7d79e
[Feature] VecNormV2: Usage with batched envs (#2901) (b08e7ac) by @vmoens ghstack-source-id: 5e14ed982b71b0e5192b0687c5259a3b49a81157
[Feature] pass policy-factory in mp data collectors (#2859) (31af2c5) by @vmoens ghstack-source-id: bce8abe9853d5ec187f91ffbcd8b940fa18ec8ab
[Feature] policy factory for collectors (#2841) (49a8a42) by @vmoens ghstack-source-id: 96b928e938b8b07fc7de23483358202737571f8e
[Feature] reset_time in Timer (#2807) (5a46379) by @vmoens ghstack-source-id: 36a74fd20b78e1cdde6bca19b4f95c3d9062d761
[Feature] transformers policy (#2825) (eea932c) by @vmoens ghstack-source-id: 870c221b4ebae132a44944f0be0ee78da540d115
Fixes
[BugFix] Apply inverse transform to input of TransformedEnv._reset (#2787) (1ed5d29) by @kurtamohler ghstack-source-id: 5f7c1fbd19b716f2b1602c34cf2ae1362f7bc7f6
[BugFix] Avoid calling reset during env init (#2770) (09e93c1) by @vmoens ghstack-source-id: 5ab8281c34aacfd7dbbfc0e285d88bcae0aededf
[BugFix] Ensure that Composite.set returns self as TensorDict does (#2784) (e084c02) by @vmoens ghstack-source-id: 23fe46b61dc2c9548fd9de7e4100431918fd0370
[BugFix] Fix .item() warning on tensors that require grad (#2885) (b66fcd4) by @vmoens ghstack-source-id: 502bdda3f5700dc900cf5c748839c965b1d67c1b
[BugFix] Fix KL penalty (#2908) (96c3003) by @vmoens ghstack-source-id: 475dccb0bcddbfe3bd2d826c5389834fb95e1ab8
[BugFix] Fix MultiAction reset (#2789) (76aa9bc) by @kurtamohler ghstack-source-id: a2f7bfdd7522a214430182dac65687a977b1a10d
[BugFix] Fix PEnv device copies (#2840) (6e40548) by @vmoens ghstack-source-id: df39fd2e4cd72f24c645b0ac32b46ab3e8d847fc
[BugFix] Fix batch_locked check in check_env_specs + error message callable (#2817) (9c98b82) by @vmoens ghstack-source-id: c722b164133c27c05dd21add3e7f3158189dd515
[BugFix] Fix calls to _reset_env_preprocess (#2798) (ea76ffb) by @vmoens ghstack-source-id: 59925635a87b196a5bcb0fb251afe4cc7b8b103e
[BugFix] Fix collector timeouts (#2774) (f6084b6) by @vmoens ghstack-source-id: cb71d95143beb22db1fe1752e72f70c19f43be79
[BugFix] Fix collector with no buffers and devices (#2809) (d4f8846) by @vmoens ghstack-source-id: 5367df9fcfdf549108be852476b049a0b978e348
[BugFix] Fix compile compatibility of PPO losses (#2889) (9bc85f4) by @vmoens ghstack-source-id: b346033641e5d27560fbfa011a006446e56a4e31
[BugFix] Fix composite setitem (#2778) (c2a149d) by @vmoens ghstack-source-id: f33b49beb4cf8c0c8b156559b1abbee8ac77db20
[BugFix] Fix env.full_done_specs (#2815) (f5c0666) by @vmoens ghstack-source-id: ba0d371d10b3f46ec1172fbec639ccc4d5559659
[BugFix] Fix forced batch-size in _skip_tensordict (#2808) (3acf491) by @vmoens ghstack-source-id: dac84e8b8835e870bce1772d7893c30b6f9af59c
[BugFix] Fix gc import (#2862) (a183f02) by @vmoens ghstack-source-id: b732d4f805d98ceaaa45326d619fce623c10482f
[BugFix] Fix lazy-stack in RBs (#2880) (e80732e) by @vmoens ghstack-source-id: 38399ee991bc065445f4eb1c84b71e7d844d794c
[BugFix] Fix property getter in RayReplayBuffer (#2869) (04d70c1) by @vmoens
[BugFix] Fix slow and flaky non-tensor parallel env test (#2926) by @vmoens ghstack-source-id: fcb5caa56e05176958b3468a7d6f69e363cfe558
[BugFix] Fix update shape mismatch in _skip_tensordict (#2792) (3e42e7a) by @vmoens ghstack-source-id: 27e7d444c126e48fdb70d951a0cc7beaee1db3a8
[BugFix] Fixed VideoRecorder crash when passing fps (#2827) (5ec9bc5) by Alexandre Brown
[BugFix] GAE warning when gamma/lmbda are tensors (#2838) (d561115) by @louisfaury Co-authored-by: Louis Faury [email protected]
[BugFix] Keep original class in LazyStackStorage through lazy_stack (#2873) (70f5c06) by @vmoens ghstack-source-id: 661cd65c86648ffb2ee6ead40110ac3d57477514
[BugFix] Non...
0.7.2: ParallelEnv fix
We are releasing TorchRL 0.7.2, a minor update that addresses several important bug fixes to improve the stability and reliability of our library.
This release is particularly crucial as it resolves a critical issue (#2840) where, under certain conditions, the device setting of the parallel environment would prevent the tensors in the buffers from being properly cloned. This resulted in rollouts returning the same tensor instances across steps, potentially leading to incorrect behavior and results.
Due to the severity of this bug, we strongly recommend that all users upgrade to TorchRL 0.7.2 to ensure the accuracy and reliability of their experiments.
The full list of changes can be found below:
- [Doc] Fix formatting errors by @vmoens (#2786)
- [BugFix] correct dim for resolving dtype in _split_and_pad_sequence by @KubaMichalczyk and vmoens (#2801)
- [BugFix] Fix collector with no buffers and devices by @vmoens (#2809)
- [BE] Fix some typos by antoinebrl and @vmoens (#2811)
- [Doc] Add docstring for MCTSForest.extend by @kurtamohler and @vmoens (#2795)
- [CI] Fix libs workflows by @vmoens (#2800)
- [BugFix] Fix env.full_done_specs by @vmoens (#2815)
- [BugFix] Fix batch_locked check in check_env_specs + error message ca… by @vmoens (#2817)
- [BugFix] GAE warning when gamma/lmbda are tensors by louisfaury and @vmoens (#2838)
- [BugFix] Tree make node fix by rolo and @vmoens (#2839)
- [BugFix] Fix PEnv device copies by @vmoens (#2840)
Full Changelog: v0.7.1...v0.7.2
0.7.1: Bug fixes and documentation improvements
We are pleased to announce the release of torchrl v0.7.1, which includes several bug fixes, documentation updates, and backend improvements.
Bug Fixes
- Fixed collector timeouts (#2774)
- Fixed composite setitem (#2778)
- Ensured that Composite.set returns self as TensorDict does (#2784)
- Fixed PPOs with composite distribution (#2791)
- Used brackets to get non-tensor data in gym envs (#2769)
- Avoided calling reset during env init (#2770)
- NonTensor should not convert anything to numpy (#2771)
Documentation Updates:
Backend Improvements:
- Made better logits in cost tests (#2775)
- Ensured abstractmethods are implemented for specs (#2790)
- Removed deprec specs from tests (#2767)
Thank you to @antoinebrl, and @louisfaury for contributing to this release!
Full Changelog: v0.7.0...v0.7.1
0.7.0: Compile compatibility, Chess and better multi-head policies
As always, we want to warmly thank the RL community for supporting this project. A special thanks to our first-time
contributors:
- @priba made their first contribution in #2543
- @carschandler made their first contribution in #2545
- @4d616e61 made their first contribution in #2624
- @valterschutz made their first contribution in #2626
- @raresdan made their first contribution in #2616
- @oslumbers made their first contribution in #2609
- @codingWhale13 made their first contribution in #2682
as well as all the users who filed issues, made suggestions, and started discussions here, on Discord,
on the PyTorch forum, or elsewhere! We value your feedback!
BC-Breaking changes and Deprecated behaviors
Removed classes
As announced, we removed the following classes:
- AdditiveGaussianWrapper
- InPlaceSampler
- NormalParamWrapper
- OrnsteinUhlenbeckProcessWrapper
Default MLP config
The default MLP depth has changed from 3 to 0 (i.e., MLP(in_features=3, out_features=4) is now equivalent to a plain
nn.Linear layer).
Locking envs
Environment specs are now carefully locked by default (#2729, #2730). This means that
env.observation_spec = spec is allowed (specs will be unlocked/re-locked automatically) but
env.observation_spec["value"] = spec will not work. The core idea here is that we want to cache as much info as we can, such as action keys or whether
the env has dynamic specs. We can only do that if we can guarantee that the env has not been modified. Locking the specs
provides us with such a guarantee.
Note that a version of this already existed but it was not as robust as the new one.
Changes to composite distributions
TL;DR: We're changing the way log-probs and entropies are collected and written in ProbabilisticTensorDictModule and
in CompositeDistribution. The "sample_log_prob" default key will soon be "<value>_log_prob" (or
("path", "to", "<value>_log_prob") for nested keys). For CompositeDistribution, a separate log-prob will be
written for each leaf tensor in the distribution. This new behavior is controlled by the
tensordict.nn.set_composite_lp_aggregate(mode: bool) function or by the COMPOSITE_LP_AGGREGATE environment variable.
We strongly encourage users to adopt the new behavior by calling tensordict.nn.set_composite_lp_aggregate(False).set()
at the beginning of their training script.
The behavior of CompositeDistribution and its interaction with on-policy losses such as PPO has changed.
The PPO documentation now includes a section on multi-head policies, and the examples illustrate this as well.
See the tensordict v0.7.0 release notes or #2707 to know more.
[Deprecation] Change the default MLP depth (#2746) (12e6bce) by @vmoens ghstack-source-id: bd34b8e9112c4fc3a30bd095e3ac073a7d0b5469
[Deprecation] Gracing old *Spec with v0.8 versioning (#2751) (fa697fe) by @vmoens ghstack-source-id: e7c6e0a4b8520da887fe7e602a351c3c72a08c4c
[Deprecation] Remove AdditiveGaussianWrapper (#2748) (6c7f4fb) by @vmoens ghstack-source-id: 78f248e1239a04fc5213aa4418a158f741679593
[Deprecation] Remove InPlaceSampler (#2750) (0feef11) by @vmoens ghstack-source-id: eeae1bf0611a5d293f533767eee7b9700e720cc8
[Deprecation] Remove NormalParamWrapper (#2747) (a38604e) by @vmoens ghstack-source-id: 4a70178f54f9e25d602c86a0b61248d66f3e39bd
[Deprecation] Remove OrnsteinUhlenbeckProcessWrapper (#2749) (0111a87) by @vmoens ghstack-source-id: 401fdfaca2e27122d5a67fc7177e1015047f0098
New features
Compile compatibility
We put a strong focus on better compatibility with torch.compile across the SOTA training scripts, which now
all accept a compile=1 argument. The overall speedups range from 1x to 4x.
Loss module speedups are displayed in the README.md page.
Replay buffers are also mostly compatible with compile now (with the notable exception of distributed and memmapped ones).
Specs: auto_specs_, <attr>_spec_unbatched
You can now use env.auto_specs_ to set the specs automatically based on a dummy rollout.
For batched environments, the unbatched spec can now be accessed via env.<attr>_spec_unbatched. This is useful to
create random policies, for example.
New transforms
We added TrajCounter (#2532), Hash and Tokenizer (#2648, #2700) and LineariseReward (#2681).
LazyStackStorage
We provide a new ListStorage-based storage, LazyStackStorage, that automatically represents samples as a LazyStackedTensorDict,
making it easy to store ragged tensors (although not contiguously in memory) (#2723).
ChessEnv
A new torchrl.envs.ChessEnv allows users to train agents to play chess!
Tutorials on exporting torchrl modules
We also open-sourced a tutorial on exporting TorchRL modules to hardware: #2557
Full list of features
[Feature, Test] Adding tests for envs that have no specs (#2621) (c72583f) by @vmoens ghstack-source-id: 4c75691baa1e70f417e518df15c4208cff189950
[Feature,Refactor] Chess improvements: fen, pgn, pixels, san, action mask (#2702) (d425777) by @vmoens ghstack-source-id: f294a2bc99a17911c9b62558d530b148d3c0350f
[Feature] A2C compatibility with compile (#2464) (507766a) by @vmoens ghstack-source-id: 66a7f0d1dd82d6463d61c1671e8e0a14ac9a55e7
[Feature] ActionDiscretizer custom sampling (#2609) (3da76f0) @oslumbers Co-authored-by: Oliver Slumbers [email protected]
[Feature] Add Hash transform (#2648) (50011dc) @kurtamohler ghstack-source-id: dccf63fe4f9d5f76947ddb7d5dedcff87ff8cdc5
[Feature] Add Choice spec (#2713) (9368ca6) @kurtamohler ghstack-source-id: afa315a311845ab39ade3e75046f32757f9d94f1
[Feature] Add LossModule.reset_parameters_recursive (#2546) (218d5bf) by @kurtamohler
[Feature] Add Stack transform (#2567) (594462d) by @kurtamohler
[Feature] Add deterministic_sample to masked categorical (#2708) (49d9897) by @vmoens ghstack-source-id: d34fcf9b44d7a7c60dbde80b0835189f990ef226
[Feature] Adds ordinal distributions (#2520) (c851e16) by @louisfaury Co-authored-by: @louisfaury
[Feature] Avoid some recompiles of ReplayBuffer.extend/sample (#2504) (0f29c7e) @kurtamohler
[Feature] CQL compatibility with compile (#2553) (e2be42e) by @vmoens ghstack-source-id: d362d6c17faa0eb609009bce004bb4766e345d5e
[Feature] CROSSQ compatibility with compile (#2554) (01a421e) by @vmoens ghstack-source-id: 98a2b30e8f6a1b0bc583a9f3c51adc2634eb8028
[Feature] CatFrames.make_rb_transform_and_sampler (#2643) (9ee1ae7) by @vmoens ghstack-source-id: 7ecf952ec9f102a831aefdba533027ff8c4c29cc
[Feature] ChessEnv (#2641) (17983d4) by @vmoens ghstack-source-id: 087c3b12cd621ea11a252b34c4896133697bce1a
[Feature] Composite.batch_size (#2597) (2e82cab) by @vmoens ghstack-source-id: 621884a559a71e80a4be36c7ba984fd08be47952
[Feature] Composite.pop (#2598) (8d16c12) by @vmoens ghstack-source-id: 64d5bd736657ef56e37d57726dfcfd25b16b699f
[Feature] Composite.separates (#2599) (83e0b05) by @vmoens ghstack-source-id: fbfc4308a81cd96ecc61723df8c0eb870c442def
[Feature] Custom conversion tool for gym specs (#2726) (dbc8e2e) by @vmoens ghstack-source-id: d38bb02f15267a9b1637b3ed25fb44ef013e2456
[Feature] DDPG compatibility with compile (#2555) (7d7cd95) by @vmoens ghstack-source-id: f18928a419f81794d6870fd4e9fe1205c1b137e1
[Feature] DQN compatibility with compile (#2571) (f149811) by @vmoens ghstack-source-id: 113dc8c4a5562d217ed867ace1942b2f6b8a39f9
[Feature] DT compatibility with compile (#2556) (fbfe104) by @vmoens ghstack-source-id: 362b6e88bad4397f35036391729e58f4f7e4a25d
[Feature] Discrete SAC compatibility with compile (#2569) (9e2d214) by @vmoens ghstack-source-id: ddc131acedbbe451b28758e757a8c240ebd72b80
[Feature] Ensure out-place policy compatibility in rollout and collectors (#2717) (ec370c6) by @vmoens ghstack-source-id: 41a6aa56e0a045a20224b96f9537a7ae3ae14494
[Feature] EnvBase.auto_specs_ (#2601) (d537dcb) by @vmoens ghstack-source-id: 329679238c5172d7ff13097ceaa189479d4f4145
[Feature] EnvBase.check_env_specs (#2600) (00d3199) by @vmoens ghstack-source-id: 332dbf92db496c71c5ce6aba340ad123eac0f5d6
[Feature] GAIL compatibility with compile (#2573) (6482766) by @vmoens ghstack-source-id: 98c7602ec0343d7a83cb19bddeb579484c42e77e
[Feature] IQL compatibility with compile (#2649) (2cfc2ab) by @vmoens ghstack-source-id: 77bca166701d28dd69ef3964f55ab4f3e4b17fed
[Feature] LLMHashingEnv (#2635) (30d21e5) by @vmoens ghstack-source-id: d1a20ecd023008683cf18cf9e694340cfdbdac8a
[Feature] LazyStackStorage (#2723) (fe3f00c) by @vmoens ghstack-source-id: e9c031470aa0bdafbb2b26c73c06b25685a128e5
[Feature] Linearise reward transform (#2681) (ff1ff7e) by @louisfaury Co-authored-by: @louisfaury
[Feature] Log each entropy for composite distributions in PPO (#2707) (319bb68) by @louisfaury Co-authored-by: @louisfaury
[Feature] Log pbar rate in SOTA implementations (#2662) (1ce25f1) by @vmoens ghstack-source-id: 283cc1bb4ad2d60281296d2cfb78ec41c77f4129
[Feature] MCTSForest (#2307) (e9d1677) by @vmoens ghstack-source-id: 9ac5cd3de39a4dbe1c7c33cb71ff6f45a886ae65
[Feature] Make PPO compatible with composite actions and log-probs (#2665) (256a700) by @vmoens ghstack-source-id: c41718e697f9b6edda17d4ddb5bd6d41402b7c30
[Feature] PPO compatibility with compile (#2652) (f5a187d) by @vmoens ghstack...