Add sarm #2639
base: main
Conversation
* Add generate and validate script
* fix precommit
* Improve generate embeddings function by using dataset tools (#2206)

Co-authored-by: Michel Aractingi <[email protected]>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```
</hfoption>
<hfoption id="dual">

Visualize annotations using the `--visualize-only` flag:
```
We can remove this example, or say that this command shows both annotations, since it looks similar to the one above.
| Argument           | Description                                      |
| ------------------ | ------------------------------------------------ |
| `--visualize-only` | Only visualize predictions (no RABC computation) |
Change RABC to RA-BC for consistency
```python
# RA-BC (Reward-Aligned Behavior Cloning) parameters
use_rabc: bool = False  # Enable reward-weighted training
rabc_progress_path: str | None = None  # Path to precomputed SARM progress parquet file
```
Is it possible to infer the SARM progress from the dataset?
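A minimal sketch of how the precomputed file might be consumed, assuming a hypothetical schema with `episode_index`, `frame_index`, and `progress` columns (not confirmed by this PR):

```python
import pandas as pd

# Hypothetical schema: one row per (episode_index, frame_index) with a SARM progress value.
progress = pd.read_parquet("path/to/sarm_progress.parquet")
weights = progress.set_index(["episode_index", "frame_index"])["progress"]

# Look up the RA-BC weight for a single sample.
w = weights.loc[(3, 120)]
```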
| "ninja>=1.11.1,<2.0.0", | ||
| "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" | ||
| ] | ||
| sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0"] |
I think we need to add matplotlib for `subtask_annotation.py`.
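A sketch of what that could look like in the `sarm` extra; the matplotlib version bounds are an assumption:

```toml
# Hypothetical: add matplotlib for the plotting in subtask_annotation.py (bounds assumed).
sarm = [
    "lerobot[transformers-dep]",
    "faker>=33.0.0,<35.0.0",
    "matplotlib>=3.8.0,<4.0.0",
]
```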
```
@@ -0,0 +1,1221 @@
#!/usr/bin/env python
```
I got a memory/RAM issue running this script with Qwen/Qwen3-VL-4B-Instruct. We need to investigate why.
```python
from pydantic import BaseModel, Field
from qwen_vl_utils import process_vision_info
from rich.console import Console
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
```
The subtask annotation can only use the Qwen model family; it may be better to make it work with other models that could outperform Qwen in the future.
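A minimal sketch of a more model-agnostic loading path using the transformers Auto classes, which would also let us try lower-memory settings for the RAM issue mentioned above; the model id and load options here are assumptions, not what the PR currently does:

```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "Qwen/Qwen3-VL-4B-Instruct"  # any image-text-to-text checkpoint would work
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # halves memory versus fp32
    device_map="auto",           # spreads weights across available devices
)
```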
```python
def __init__(
    self,
    config: ACTConfig,
    **kwargs,
```
We need to add per-sample losses.
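A minimal sketch of what per-sample losses could look like, assuming an L1 action loss as in ACT; the tensor names and `rabc_weights` argument are illustrative, not the actual implementation (the same pattern would apply to the other policies below):

```python
import torch
import torch.nn.functional as F

def act_loss(
    pred_actions: torch.Tensor,    # (B, T, D) predicted action chunks
    target_actions: torch.Tensor,  # (B, T, D) ground-truth actions
    rabc_weights: torch.Tensor | None = None,  # (B,) hypothetical RA-BC weights
) -> torch.Tensor:
    # Keep per-element losses instead of reducing to a scalar immediately.
    l1 = F.l1_loss(pred_actions, target_actions, reduction="none")  # (B, T, D)
    per_sample = l1.mean(dim=(1, 2))                                # (B,)
    # RA-BC reweights each sample before the final reduction.
    if rabc_weights is not None:
        per_sample = per_sample * rabc_weights
    return per_sample.mean()
```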
```python
def __init__(
    self,
    config: DiffusionConfig,
    **kwargs,
```
We need to add per-sample losses here too.
```diff
 config_class = GrootConfig

-    def __init__(self, config: GrootConfig):
+    def __init__(self, config: GrootConfig, **kwargs):
```
We need to add per-sample losses here too.
```python
def __init__(
    self,
    config: TDMPCConfig,
    **kwargs,
```
We need to add per-sample losses here too.
| name = "xvla" | ||
|
|
||
| def __init__(self, config: XVLAConfig): | ||
| def __init__(self, config: XVLAConfig, **kwargs): |
We need to add per-sample losses here too.
```python
def __init__(
    self,
    config: VQBeTConfig | None = None,
    **kwargs,
```
We need to add per-sample losses here too.
| "StageTransformer", | ||
| "SubtaskTransformer", | ||
| "gen_stage_emb", | ||
| "SARMEncodingProcessorStep", |
Do we need to add those here?
```bash
# Full RABC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \
    --dataset-repo-id lerobot/aloha_sim_insertion_human \
    --reward-model-path pepijn223/sarm_single_uni4

# Faster computation with stride (compute every 5 frames, interpolate the rest)
python src/lerobot/policies/sarm/compute_rabc_weights.py \
    --dataset-repo-id lerobot/aloha_sim_insertion_human \
    --reward-model-path pepijn223/sarm_single_uni4 \
    --stride 5

# Visualize predictions only (no RABC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \
    --dataset-repo-id lerobot/aloha_sim_insertion_human \
    --reward-model-path pepijn223/sarm_single_uni4 \
    --visualize-only \
    --num-visualizations 5
```
It would be nice to change RABC to RA-BC here as well.
```python
    return img


def visualize_episode(
```
Does this work the same as the `visualize_episode` function in `subtask_annotation.py`?
s1lent4gnt left a comment
First code review looks good to me.
I want to discuss these points:
- Should we add a separate folder for reward models, since we will later add ReWiND?
- Do we really need to use the Faker lib?
- Do we need to let users add manual dataset subtask annotations as an option?
- The original SARM implementation uses two different optimizers for `StageTransformer` and `SubtaskTransformer`; should we do the same? (See the sketch below.)
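A minimal sketch of the two-optimizer setup under discussion; the attribute names, learning rates, and `compute_loss` call are assumptions, not the original SARM code:

```python
import torch

# Hypothetical: one optimizer (and set of hyperparameters) per head.
stage_opt = torch.optim.AdamW(model.stage_transformer.parameters(), lr=1e-4)
subtask_opt = torch.optim.AdamW(model.subtask_transformer.parameters(), lr=5e-5)

for batch in dataloader:
    stage_opt.zero_grad()
    subtask_opt.zero_grad()
    loss = model.compute_loss(batch)  # hypothetical joint loss over both heads
    loss.backward()
    stage_opt.step()   # each optimizer updates only its own parameter group
    subtask_opt.step()
```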
Differences compared to original:

LeRobot:
- Clamps indices to valid bounds and applies copy-padding at the beginning and end.
- Samples frames in a bidirectional manner (both backward- and forward-looking); see the sketch below.

Original:
- Uses an adaptive stride when looking backward and there is not enough history.
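A minimal sketch of the LeRobot-style bidirectional sampling with clamping, under the assumptions above; the function name and signature are illustrative:

```python
def sample_frame_indices(t: int, num_frames: int, ep_len: int, stride: int) -> list[int]:
    """Hypothetical bidirectional sampler: frames around t, clamped to the episode."""
    half = num_frames // 2
    # Look both backward and forward around the current frame t.
    idxs = [t + (i - half) * stride for i in range(num_frames)]
    # Clamping out-of-range indices to the first/last frame reproduces the
    # copy-padding behavior at episode boundaries.
    return [min(max(i, 0), ep_len - 1) for i in idxs]

# e.g. sample_frame_indices(t=2, num_frames=5, ep_len=100, stride=4)
# -> [0, 0, 2, 6, 10]  (start of the episode padded by clamping)
```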
Progress visualization of a trained SARM model for sparse and dense annotations (sparse: inference on every frame; dense: stride = 30):


Inference example on a failed episode (sparse and dense), not in the training data: https://huggingface.co/datasets/lerobot-data-collection/eval_pi0_fold_11-30_1

TODO: