Conversation
e746aa5 to
cf0b4ce
Compare
|
|
||
|
|
||
| @torch.no_grad() | ||
| def forward( |
There was a problem hiding this comment.
Can you explain why do we need to override the forward method in this way?
There was a problem hiding this comment.
-
The select_action method is responsible for inference (as described in source). The Policy.forward method in LeRobot is only used during training to compute the loss and is not part of the inference path.
-
Since our use case targets inference, we needed to override forward to incorporate the
select_actionlogic, ensuring the correct execution flow during model inference. (as already mentioned in the forward method docstring) -
Additionally, the steps inside
preprocess_for_sampling, which are normally invoked withinselect_action, have been separated and applied directly during theload_inputsstage. This is because those steps are specifically responsible for preprocessing a batch before action sampling.
pi_0/pytorch/src/model.py
Outdated
| return self._action_queue.popleft() | ||
|
|
||
|
|
||
| PI0Policy.preprocess_for_sampling = preprocess_for_sampling |
There was a problem hiding this comment.
can you avoid global module overriding function like this. You can make a function like get_custom_pi0_policy that returns this object instead of it being global
There was a problem hiding this comment.
-
As suggested, I’ve added a
get_custom_pi0_policyfactory function. This function loads the originalPI0Policy, overrides theforwardandpreprocess_for_samplingmethods on that instance, and returns the modified policy object. -
This avoids any global module-level overrides and keeps the changes scoped to the specific instance being used.
pi_0/pytorch/loader.py
Outdated
| torch.nn.Module: The Pi-0 Policy instance. | ||
| """ | ||
| from lerobot.policies.pi0 import PI0Policy | ||
| from .src import model |
There was a problem hiding this comment.
due to the global PI0Policy object I mentioned earlier, it is not clear whether you are using the original PI0Policy object, or is it getting overwritten by the one from .src.model? Please change to from .src.model import get_custom_pi0_policy and then I assume this from lerobot.policies.pi0 import PI0Policy import will be unnessesary
There was a problem hiding this comment.
-
I’ve updated the
load_modelfunction to instantiate the model usingget_custom_pi0_policy, which internally handles importing the originalPI0Policyand applying the required overrides. -
As a result, the direct import
from lerobot.policies.pi0 import PI0Policyhas been removed from this file to avoid any confusion about which policy object is being used.
| @@ -0,0 +1,7 @@ | |||
| transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi | |||
There was a problem hiding this comment.
Is this the standard way to do this? I did a quick google search and I found this pip dependency for PI0:
'For lerobot 0.4.0, if you want to install pi tag, you will have to do: pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git".'
There was a problem hiding this comment.
- this
pip install "lerobot[pi]@git+[https://github.com/huggingface/lerobot.gitinstalls LeRobot’s default dependencies include versions of torch(<2.8) and other related ones that conflict with our target XLA environment. - Because of that, we install LeRobot itself with
--no-deps, then install only the required dependencies manually (as listed in our requirements.txt), ensuring compatibility with the current environment.
408a21e to
db5bc7e
Compare
db5bc7e to
92e26be
Compare
92e26be to
afef356
Compare
Ticket
Link to Github Issue - tenstorrent/tt-xla#2979
Problem description
to bringup Pi_0 model
What's changed
added Pi_0 pytorch model
Checklist
Logs