diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index df559555a7f..08ff3bc1580 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -392,6 +392,17 @@ def __post_init__(self): add_bos_token=False, ) + # Cache constant token sequences to avoid recomputing them every iteration. + self._prefix_tokens = torch.tensor( + [self._paligemma_tokenizer.bos_token_id] + + self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), + dtype=torch.long, + ) + self._suffix_tokens = torch.tensor( + self._paligemma_tokenizer.encode("|"), + dtype=torch.long, + ) + def __call__(self, transition: EnvTransition) -> EnvTransition: """ Applies action tokenization to the transition. @@ -445,7 +456,7 @@ def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Te if action is None: raise ValueError("Action cannot be None") - # Get the device and dtype of the input action + # Get the device of the input action device = action.device if isinstance(action, torch.Tensor) else None # Handle single sample (add batch dimension) @@ -455,38 +466,34 @@ def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Te batch_size = action.shape[0] - # Tokenize the action batch - # The fast tokenizer expects action data and returns token IDs + # Move entire batch to CPU once to avoid per-sample GPU-CPU synchronization. + # The tokenizer uses scipy/numpy internally and requires CPU tensors. + action_cpu = action.cpu() + + # Tokenize the action batch. All work is done on CPU; the result is + # transferred to the target device in a single bulk operation at the end. tokens_list = [] masks_list = [] for i in range(batch_size): - # Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy) - action_cpu = action[i : i + 1].cpu() - tokens = self.action_tokenizer(action_cpu) + tokens = self.action_tokenizer(action_cpu[i : i + 1]) - # Convert to numpy array if it's a list if isinstance(tokens, list) or not isinstance(tokens, torch.Tensor): - tokens = torch.tensor(tokens, dtype=torch.long, device=action.device) + tokens = torch.tensor(tokens, dtype=torch.long) else: - # Move tokens back to the same device as input action - tokens = tokens.to(device=action.device) + tokens = tokens.to(dtype=torch.long) # Flatten to 1D if needed if tokens.dim() > 1: tokens = tokens.flatten() - bos_id = self._paligemma_tokenizer.bos_token_id - # add bos + # Prepend prefix (bos + "Action: ") and append suffix ("|"), + # using cached constant tensors. tokens = torch.cat( [ - torch.tensor([bos_id], device=action.device), - torch.tensor( - self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), - device=action.device, - ), + self._prefix_tokens, self._act_tokens_to_paligemma_tokens(tokens), - torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device), + self._suffix_tokens, ] ) @@ -497,23 +504,16 @@ def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Te "Consider increasing the `max_action_tokens` in your model config if this happens frequently." ) tokens = tokens[: self.max_action_tokens] - mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device) + mask = torch.ones(self.max_action_tokens, dtype=torch.bool) else: - mask = torch.cat( - [ - torch.ones(len(tokens), dtype=torch.bool, device=action.device), - torch.zeros( - self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device - ), - ] - ) - # Pad tokens with zeros + mask = torch.zeros(self.max_action_tokens, dtype=torch.bool) + mask[: len(tokens)] = True tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0) tokens_list.append(tokens) masks_list.append(mask) - # Stack into batched tensors + # Stack into batched tensors and move to the target device in one transfer. tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens) masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens) @@ -522,7 +522,6 @@ def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Te tokens_batch = tokens_batch.squeeze(0) masks_batch = masks_batch.squeeze(0) - # Move to the same device as the input if device is not None: tokens_batch = tokens_batch.to(device) masks_batch = masks_batch.to(device)