|
1 | 1 | import logging |
| 2 | +import os |
2 | 3 |
|
| 4 | +import jax |
3 | 5 | import numpy as np |
| 6 | +import orbax.checkpoint as ocp |
4 | 7 | import sentencepiece |
5 | 8 | from transformers import AutoProcessor |
6 | 9 |
|
| 10 | +import openpi.models.utils.fsq_tokenizer as fsq_tokenizer |
7 | 11 | import openpi.shared.download as download |
8 | 12 |
|
9 | 13 |
|
@@ -125,3 +129,235 @@ def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np. |
125 | 129 | if isinstance(tokens, list): |
126 | 130 | tokens = np.array(tokens) |
127 | 131 | return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens |
| 132 | + |
| 133 | + |
| 134 | +########################################################################### |
| 135 | +## The tokenizers below are used for RoboArena baseline implementations. ## |
| 136 | +## They are *not* used for pi0-style models. ## |
| 137 | +########################################################################### |
| 138 | + |
| 139 | + |
| 140 | +class BinningTokenizer: |
| 141 | + """ |
| 142 | + Standard RT-2 / OpenVLA style binning tokenizer. |
| 143 | + """ |
| 144 | + |
| 145 | + def __init__(self, max_len: int = 256, n_bins: int = 256): |
| 146 | + self._max_len = max_len |
| 147 | + self._n_bins = n_bins |
| 148 | + |
| 149 | + # Download base PaliGemma tokenizer |
| 150 | + path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) |
| 151 | + with path.open("rb") as f: |
| 152 | + self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) |
| 153 | + |
| 154 | + self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens |
| 155 | + |
| 156 | + def tokenize( |
| 157 | + self, prompt: str, state: np.ndarray, actions: np.ndarray | None |
| 158 | + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: |
| 159 | + """Tokenize a prompt and state into a sequence of tokens. |
| 160 | +
|
| 161 | + Args: |
| 162 | + prompt: The text prompt to tokenize. |
| 163 | + state: The state array to discretize and tokenize. |
| 164 | + actions: Must be None. Action encoding is not currently supported. |
| 165 | +
|
| 166 | + Returns: |
| 167 | + A tuple of (tokens, token_mask, ar_mask, targets). |
| 168 | +
|
| 169 | + Raises: |
| 170 | + NotImplementedError: If actions is not None. |
| 171 | + """ |
| 172 | + cleaned_text = prompt.lower().strip().replace("_", " ") |
| 173 | + |
| 174 | + # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1]) |
| 175 | + discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 |
| 176 | + |
| 177 | + # Convention: prefix includes prompt and string-representation of state, followed by ';' |
| 178 | + state_str = " ".join(map(str, discretized_state)) |
| 179 | + prefix = f"Task: {cleaned_text}, State: {state_str};\n" |
| 180 | + prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True) |
| 181 | + |
| 182 | + if actions is not None: |
| 183 | + raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)") |
| 184 | + postfix_tokens = [] |
| 185 | + |
| 186 | + # Create output token sequence & masks |
| 187 | + # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens) |
| 188 | + tokens = prefix_tokens + postfix_tokens |
| 189 | + token_mask = [True] * len(tokens) |
| 190 | + ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens) |
| 191 | + loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only |
| 192 | + |
| 193 | + # Pad tokens to max length |
| 194 | + tokens_len = len(tokens) |
| 195 | + if tokens_len < self._max_len: |
| 196 | + padding = [False] * (self._max_len - tokens_len) |
| 197 | + tokens = tokens + padding |
| 198 | + token_mask = token_mask + padding |
| 199 | + ar_mask = ar_mask + padding |
| 200 | + loss_mask = loss_mask + padding |
| 201 | + else: |
| 202 | + if len(tokens) > self._max_len: |
| 203 | + logging.warning( |
| 204 | + f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " |
| 205 | + "Consider increasing the `max_token_len` in your model config if this happens frequently." |
| 206 | + ) |
| 207 | + tokens = tokens[: self._max_len] |
| 208 | + token_mask = token_mask[: self._max_len] |
| 209 | + ar_mask = ar_mask[: self._max_len] |
| 210 | + loss_mask = loss_mask[: self._max_len] |
| 211 | + |
| 212 | + return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) |
| 213 | + |
| 214 | + def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: |
| 215 | + # Decode predicted output tokens |
| 216 | + decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) |
| 217 | + |
| 218 | + # Extract actions from FAST model outputs |
| 219 | + if "Action: " not in decoded_tokens: |
| 220 | + return np.zeros((action_horizon, action_dim), dtype=np.float32) |
| 221 | + |
| 222 | + # Extract actions from decoded tokens |
| 223 | + raw_action_tokens = np.array( |
| 224 | + self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) |
| 225 | + ) |
| 226 | + action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) |
| 227 | + if len(action_tokens) < action_horizon * action_dim: |
| 228 | + return np.zeros([action_horizon, action_dim], dtype=np.float32) |
| 229 | + action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim]) |
| 230 | + return action_tokens / self._n_bins * 2 - 1 |
| 231 | + |
| 232 | + def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: |
| 233 | + if isinstance(tokens, list): |
| 234 | + tokens = np.array(tokens) |
| 235 | + return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens |
| 236 | + |
| 237 | + |
| 238 | +class FSQTokenizer: |
| 239 | + """ |
| 240 | + FSQ tokenizer from the FAST paper baselines. |
| 241 | + """ |
| 242 | + |
| 243 | + def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None): |
| 244 | + self._max_len = max_len |
| 245 | + |
| 246 | + assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided" |
| 247 | + # Download tokenizer |
| 248 | + path = download.maybe_download(fsq_tokenizer_path) |
| 249 | + tok_path = os.path.join(path, os.listdir(path)[0]) # noqa: PTH118 |
| 250 | + |
| 251 | + # Split step from path |
| 252 | + step = int(tok_path.split("/")[-1]) |
| 253 | + base_path = tok_path.rsplit("/", 1)[0] |
| 254 | + |
| 255 | + mgr = ocp.CheckpointManager( |
| 256 | + base_path, |
| 257 | + item_handlers={ |
| 258 | + "params": ocp.StandardCheckpointHandler(), |
| 259 | + "opt_state": ocp.StandardCheckpointHandler(), |
| 260 | + "config": ocp.JsonCheckpointHandler(), |
| 261 | + }, |
| 262 | + options=ocp.CheckpointManagerOptions(max_to_keep=1), |
| 263 | + ) |
| 264 | + |
| 265 | + try: |
| 266 | + restored = mgr.restore( |
| 267 | + step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore()) |
| 268 | + ) |
| 269 | + config = restored["config"] |
| 270 | + self._params = restored["params"] |
| 271 | + self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config) |
| 272 | + except Exception as e: |
| 273 | + raise RuntimeError( |
| 274 | + f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}" |
| 275 | + ) from e |
| 276 | + |
| 277 | + # Compile tokenize and detokenize functions |
| 278 | + self._tokenize_fn = jax.jit( |
| 279 | + lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize) |
| 280 | + ) |
| 281 | + self._detokenize_fn = jax.jit( |
| 282 | + lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize) |
| 283 | + ) |
| 284 | + |
| 285 | + # Download base PaliGemma tokenizer |
| 286 | + path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) |
| 287 | + with path.open("rb") as f: |
| 288 | + self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) |
| 289 | + |
| 290 | + self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens |
| 291 | + |
| 292 | + def tokenize( |
| 293 | + self, prompt: str, state: np.ndarray, actions: np.ndarray | None |
| 294 | + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: |
| 295 | + cleaned_text = prompt.lower().strip().replace("_", " ") |
| 296 | + |
| 297 | + # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1]) |
| 298 | + discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 |
| 299 | + |
| 300 | + # Convention: prefix includes prompt and string-representation of state, followed by ';' |
| 301 | + state_str = " ".join(map(str, discretized_state)) |
| 302 | + prefix = f"Task: {cleaned_text}, State: {state_str};\n" |
| 303 | + prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True) |
| 304 | + |
| 305 | + if actions is not None: |
| 306 | + raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)") |
| 307 | + postfix_tokens = [] |
| 308 | + |
| 309 | + # Create output token sequence & masks |
| 310 | + # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens) |
| 311 | + tokens = prefix_tokens + postfix_tokens |
| 312 | + token_mask = [True] * len(tokens) |
| 313 | + ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens) |
| 314 | + loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only |
| 315 | + |
| 316 | + # Pad tokens to max length |
| 317 | + tokens_len = len(tokens) |
| 318 | + if tokens_len < self._max_len: |
| 319 | + padding = [False] * (self._max_len - tokens_len) |
| 320 | + tokens = tokens + padding |
| 321 | + token_mask = token_mask + padding |
| 322 | + ar_mask = ar_mask + padding |
| 323 | + loss_mask = loss_mask + padding |
| 324 | + else: |
| 325 | + if len(tokens) > self._max_len: |
| 326 | + logging.warning( |
| 327 | + f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " |
| 328 | + "Consider increasing the `max_token_len` in your model config if this happens frequently." |
| 329 | + ) |
| 330 | + tokens = tokens[: self._max_len] |
| 331 | + token_mask = token_mask[: self._max_len] |
| 332 | + ar_mask = ar_mask[: self._max_len] |
| 333 | + loss_mask = loss_mask[: self._max_len] |
| 334 | + |
| 335 | + return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) |
| 336 | + |
| 337 | + def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: |
| 338 | + # Decode predicted output tokens |
| 339 | + decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) |
| 340 | + |
| 341 | + # Extract actions from FAST model outputs |
| 342 | + if "Action: " not in decoded_tokens: |
| 343 | + return np.zeros((action_horizon, action_dim), dtype=np.float32) |
| 344 | + |
| 345 | + # Extract actions from decoded tokens |
| 346 | + raw_action_tokens = np.array( |
| 347 | + self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) |
| 348 | + ) |
| 349 | + action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) |
| 350 | + try: |
| 351 | + # Move computation to CPU and compile on-demand |
| 352 | + device = jax.devices("cpu")[0] |
| 353 | + with jax.default_device(device): |
| 354 | + detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0] |
| 355 | + return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim]) |
| 356 | + except Exception as e: |
| 357 | + logging.warning(f"Error decoding FSQ: {e}") |
| 358 | + return np.zeros((action_horizon, action_dim)) |
| 359 | + |
| 360 | + def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: |
| 361 | + if isinstance(tokens, list): |
| 362 | + tokens = np.array(tokens) |
| 363 | + return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens |
0 commit comments