Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 29 additions & 30 deletions src/lerobot/processor/tokenizer_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,17 @@ def __post_init__(self):
add_bos_token=False,
)

# Cache constant token sequences to avoid recomputing them every iteration.
self._prefix_tokens = torch.tensor(
[self._paligemma_tokenizer.bos_token_id]
+ self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
dtype=torch.long,
)
self._suffix_tokens = torch.tensor(
self._paligemma_tokenizer.encode("|"),
dtype=torch.long,
)

def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Applies action tokenization to the transition.
Expand Down Expand Up @@ -445,7 +456,7 @@ def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Te
if action is None:
raise ValueError("Action cannot be None")

# Get the device and dtype of the input action
# Get the device of the input action
device = action.device if isinstance(action, torch.Tensor) else None

# Handle single sample (add batch dimension)
Expand All @@ -455,38 +466,34 @@ def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Te

batch_size = action.shape[0]

# Tokenize the action batch
# The fast tokenizer expects action data and returns token IDs
# Move entire batch to CPU once to avoid per-sample GPU-CPU synchronization.
# The tokenizer uses scipy/numpy internally and requires CPU tensors.
action_cpu = action.cpu()

# Tokenize the action batch. All work is done on CPU; the result is
# transferred to the target device in a single bulk operation at the end.
tokens_list = []
masks_list = []

for i in range(batch_size):
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
action_cpu = action[i : i + 1].cpu()
tokens = self.action_tokenizer(action_cpu)
tokens = self.action_tokenizer(action_cpu[i : i + 1])

# Convert to numpy array if it's a list
if isinstance(tokens, list) or not isinstance(tokens, torch.Tensor):
tokens = torch.tensor(tokens, dtype=torch.long, device=action.device)
tokens = torch.tensor(tokens, dtype=torch.long)
else:
# Move tokens back to the same device as input action
tokens = tokens.to(device=action.device)
tokens = tokens.to(dtype=torch.long)

# Flatten to 1D if needed
if tokens.dim() > 1:
tokens = tokens.flatten()

bos_id = self._paligemma_tokenizer.bos_token_id
# add bos
# Prepend prefix (bos + "Action: ") and append suffix ("|"),
# using cached constant tensors.
tokens = torch.cat(
[
torch.tensor([bos_id], device=action.device),
torch.tensor(
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
device=action.device,
),
self._prefix_tokens,
self._act_tokens_to_paligemma_tokens(tokens),
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
self._suffix_tokens,
]
)

Expand All @@ -497,23 +504,16 @@ def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Te
"Consider increasing the `max_action_tokens` in your model config if this happens frequently."
)
tokens = tokens[: self.max_action_tokens]
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
mask = torch.ones(self.max_action_tokens, dtype=torch.bool)
else:
mask = torch.cat(
[
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
torch.zeros(
self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device
),
]
)
# Pad tokens with zeros
mask = torch.zeros(self.max_action_tokens, dtype=torch.bool)
mask[: len(tokens)] = True
tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0)

tokens_list.append(tokens)
masks_list.append(mask)

# Stack into batched tensors
# Stack into batched tensors and move to the target device in one transfer.
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)

Expand All @@ -522,7 +522,6 @@ def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Te
tokens_batch = tokens_batch.squeeze(0)
masks_batch = masks_batch.squeeze(0)

# Move to the same device as the input
if device is not None:
tokens_batch = tokens_batch.to(device)
masks_batch = masks_batch.to(device)
Expand Down