Skip to content

Bringup Pi_0 Pytorch model#476

Merged
ashokkumarkannan1 merged 3 commits intomainfrom
akannan/pi_0_bringup
Feb 13, 2026
Merged

Bringup Pi_0 Pytorch model#476
ashokkumarkannan1 merged 3 commits intomainfrom
akannan/pi_0_bringup

Conversation

@ashokkumarkannan1
Copy link
Contributor

@ashokkumarkannan1 ashokkumarkannan1 commented Feb 11, 2026

Ticket

Link to Github Issue - tenstorrent/tt-xla#2979

Problem description

to bringup Pi_0 model

What's changed

added Pi_0 pytorch model

Checklist

  • New/Existing tests provide coverage for changes

Logs

@ashokkumarkannan1 ashokkumarkannan1 changed the title bringup Pi_0 Pytorch model Bringup Pi_0 Pytorch model Feb 11, 2026


@torch.no_grad()
def forward(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why do we need to override the forward method in this way?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • The select_action method is responsible for inference (as described in source). The Policy.forward method in LeRobot is only used during training to compute the loss and is not part of the inference path.

  • Since our use case targets inference, we needed to override forward to incorporate the select_action logic, ensuring the correct execution flow during model inference. (as already mentioned in the forward method docstring)

  • Additionally, the steps inside preprocess_for_sampling, which are normally invoked within select_action, have been separated and applied directly during the load_inputs stage. This is because those steps are specifically responsible for preprocessing a batch before action sampling.

return self._action_queue.popleft()


PI0Policy.preprocess_for_sampling = preprocess_for_sampling
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you avoid global module overriding function like this. You can make a function like get_custom_pi0_policy that returns this object instead of it being global

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • As suggested, I’ve added a get_custom_pi0_policy factory function. This function loads the original PI0Policy, overrides the forward and preprocess_for_sampling methods on that instance, and returns the modified policy object.

  • This avoids any global module-level overrides and keeps the changes scoped to the specific instance being used.

torch.nn.Module: The Pi-0 Policy instance.
"""
from lerobot.policies.pi0 import PI0Policy
from .src import model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

due to the global PI0Policy object I mentioned earlier, it is not clear whether you are using the original PI0Policy object, or is it getting overwritten by the one from .src.model? Please change to from .src.model import get_custom_pi0_policy and then I assume this from lerobot.policies.pi0 import PI0Policy import will be unnessesary

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • I’ve updated the load_model function to instantiate the model using get_custom_pi0_policy, which internally handles importing the original PI0Policy and applying the required overrides.

  • As a result, the direct import from lerobot.policies.pi0 import PI0Policy has been removed from this file to avoid any confusion about which policy object is being used.

@@ -0,0 +1,7 @@
transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the standard way to do this? I did a quick google search and I found this pip dependency for PI0:

'For lerobot 0.4.0, if you want to install pi tag, you will have to do: pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git".'

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • this pip install "lerobot[pi]@git+[https://github.com/huggingface/lerobot.git installs LeRobot’s default dependencies include versions of torch(<2.8) and other related ones that conflict with our target XLA environment.
  • Because of that, we install LeRobot itself with --no-deps, then install only the required dependencies manually (as listed in our requirements.txt), ensuring compatibility with the current environment.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok tnx, makes sense

@ashokkumarkannan1 ashokkumarkannan1 force-pushed the akannan/pi_0_bringup branch 2 times, most recently from 408a21e to db5bc7e Compare February 12, 2026 12:04
Copy link
Contributor

@ppadjinTT ppadjinTT left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@ashokkumarkannan1 ashokkumarkannan1 merged commit 409fa10 into main Feb 13, 2026
2 checks passed
@ashokkumarkannan1 ashokkumarkannan1 deleted the akannan/pi_0_bringup branch February 13, 2026 13:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants