
Commit db3374b

kpertsch (Karl Pertsch) authored and committed
Apply suggestions from code review
1 parent d5501c2 commit db3374b

File tree

5 files changed (+183, -131 lines)


examples/droid/README.md

Lines changed: 2 additions & 0 deletions
@@ -66,3 +66,5 @@ uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_dr
 # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
 uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
 ```
+
+You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).
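Once a server is running, a client process can query it over the network. The sketch below is a hypothetical client-side example assuming the `openpi_client` websocket interface; the host, port, and observation contents are placeholders, not the DROID-specific observation format expected by these checkpoints.

```python
# Hypothetical client sketch; consult the DROID example for the actual observation format.
from openpi_client import websocket_client_policy

client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
observation = {}  # fill with DROID-formatted images, state, and prompt
result = client.infer(observation)
action_chunk = result["actions"]
```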

src/openpi/models/tokenizer.py

Lines changed: 13 additions & 0 deletions
@@ -156,6 +156,19 @@ def __init__(self, max_len: int = 256, n_bins: int = 256):
     def tokenize(
         self, prompt: str, state: np.ndarray, actions: np.ndarray | None
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+        """Tokenize a prompt and state into a sequence of tokens.
+
+        Args:
+            prompt: The text prompt to tokenize.
+            state: The state array to discretize and tokenize.
+            actions: Must be None. Action encoding is not currently supported.
+
+        Returns:
+            A tuple of (tokens, token_mask, ar_mask, targets).
+
+        Raises:
+            NotImplementedError: If actions is not None.
+        """
         cleaned_text = prompt.lower().strip().replace("_", " ")
 
         # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
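For reference, the convention in the comment above maps normalized state values in [-1, 1] onto 256 discrete bins. A rough sketch of that mapping follows; the exact bin-edge handling in `tokenizer.py` may differ, so treat this as illustrative only.

```python
import numpy as np

# Sketch of the discretization convention only; bin-edge handling is an assumption.
n_bins = 256
state = np.array([-1.0, -0.3, 0.0, 0.7, 1.0])
edges = np.linspace(-1, 1, n_bins + 1)[1:-1]   # 255 interior edges
bin_ids = np.digitize(state, edges)            # integer ids in [0, 255]
print(bin_ids)                                 # [  0  89 128 217 255]
```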

src/openpi/models/utils/fsq_tokenizer.py

Lines changed: 47 additions & 38 deletions
@@ -1,5 +1,5 @@
 import math
-from typing import Literal
+from typing import Any, Literal
 
 import chex
 from einops import einops
@@ -20,7 +20,7 @@ class FsqCodebook(nn.Module):
     _bins_per_dim: tuple[int] | None = None
 
     @property
-    def bins_per_dim(self):
+    def bins_per_dim(self) -> tuple[int]:
         if self._bins_per_dim is not None:
             return self._bins_per_dim
 
@@ -34,14 +34,14 @@ def bins_per_dim(self):
         raise ValueError(f"Codebook type {self.codebook_type} not supported.")
 
     @property
-    def place_values(self):
+    def place_values(self) -> jnp.ndarray:
         place_values = [1]
         for b in self.bins_per_dim[:-1]:
             place_values.append(place_values[-1] * b)
         return jnp.array(place_values)
 
     @staticmethod
-    def _get_bins_fsq(target_codebook_size):
+    def _get_bins_fsq(target_codebook_size: int) -> tuple[int]:
         """
         Get bins per dimension based on codebook size, from the original FSQ paper.
         """
@@ -59,7 +59,7 @@ def _get_bins_fsq(target_codebook_size):
         raise ValueError(f"Codebook size {target_codebook_size} not supported.")
 
     @staticmethod
-    def _get_bins_custom(target_codebook_size):
+    def _get_bins_custom(target_codebook_size: int) -> tuple[int]:
         if target_codebook_size == 2**8:
             return (16, 16)
         elif target_codebook_size == 2**10:  # noqa: RET505
@@ -73,7 +73,7 @@ def _get_bins_custom(target_codebook_size):
         return None
 
     @staticmethod
-    def _get_bins_lfq(target_codebook_size):
+    def _get_bins_lfq(target_codebook_size: int) -> tuple[int]:
         """
         Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension)
         """
@@ -85,12 +85,12 @@ def setup(self):
         self.proj_down = nn.Dense(len(self.bins_per_dim))
         self.proj_up = nn.Dense(self.input_dim)
 
-    def __call__(self, inputs):
+    def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
         tokens, z = self.encode(inputs)
         output = self.decode(tokens, z_grad=z)
         return tokens, output
 
-    def encode(self, inputs):
+    def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
         bases = jnp.array(self.bins_per_dim)
 
         x = self.proj_down(inputs)
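For readers unfamiliar with FSQ: after `proj_down`, each latent dimension is squashed to a bounded range and rounded to one of a small number of levels, with a straight-through gradient. The snippet below sketches the generic FSQ quantization step from the FSQ paper; it is illustrative and not necessarily the exact bounding used in `FsqCodebook.encode`, and the bin counts are example values.

```python
import jax
import jax.numpy as jnp

# Generic FSQ quantization sketch; bins (8, 6, 5) are example values, not this repo's config.
bins_per_dim = jnp.array([8, 6, 5])

def fsq_quantize(x: jnp.ndarray) -> jnp.ndarray:
    half = (bins_per_dim - 1) / 2.0
    z = jnp.tanh(x) * half                     # bound each dim to roughly [-(b-1)/2, (b-1)/2]
    z_q = jnp.round(z)                         # snap to one of b integer levels per dim
    return z + jax.lax.stop_gradient(z_q - z)  # straight-through estimator

print(fsq_quantize(jnp.array([0.3, -1.2, 2.0])))
```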
@@ -102,7 +102,7 @@ def encode(self, inputs):
 
         return tokens, z
 
-    def decode(self, tokens, z_grad: jax.Array | None = None):
+    def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None) -> jnp.ndarray:
         bases = jnp.array(self.bins_per_dim)
         digits = self.digitize(tokens)
 
@@ -114,14 +114,14 @@ def decode(self, tokens, z_grad: jax.Array | None = None):
 
         return self.proj_up(z_q)
 
-    def undigitize(self, digits):
+    def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray:
         return jnp.sum(digits * jnp.array(self.place_values), axis=-1)
 
-    def digitize(self, tokens):
+    def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray:
         return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)
 
     @property
-    def vocab_size(self):
+    def vocab_size(self) -> int:
         return math.prod(self.bins_per_dim)
 
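The `digitize`/`undigitize` pair above is a mixed-radix conversion between a flat token id and per-dimension bin indices, with `place_values` holding the cumulative products of the preceding bin counts. A small worked example (the bin counts are illustrative, not this repo's configuration):

```python
import jax.numpy as jnp

# Example bins only: 3 latent dims with 8, 6, and 5 bins.
bins_per_dim = jnp.array([8, 6, 5])
place_values = jnp.array([1, 8, 48])  # 1, 8, 8*6

digits = jnp.array([3, 4, 2])                    # per-dimension bin indices
token = jnp.sum(digits * place_values, axis=-1)  # undigitize: 3 + 4*8 + 2*48 = 131

recovered = (token[..., None] // place_values) % bins_per_dim  # digitize
assert jnp.all(recovered == digits)
```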

@@ -132,7 +132,7 @@ class ResNetDownBlock(nn.Module):
     group_size: int = 32
 
     @nn.compact
-    def __call__(self, x, *, train=True):
+    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
         skip = x
 
         if self.stride > 1 or x.shape[-1] != self.n_filters:
@@ -154,7 +154,7 @@ class ResNetUpBlock(nn.Module):
     group_size: int = 32
 
     @nn.compact
-    def __call__(self, x, *, train=True):
+    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
         skip = x
 
         if self.stride > 1:
@@ -184,30 +184,29 @@ class LookupFreeQuantization(nn.Module):
 
     def setup(self):
         self.codebook = jnp.array([-1, 1])
-        # self.activation = lambda x: x
         self.activation = nn.tanh
 
         self.project_down = nn.Dense(self.num_dims)
         self.project_up = nn.Dense(self.latent_dim)
 
-    def encode(self, z):
+    def encode(self, z: jnp.ndarray) -> jnp.ndarray:
         z = self.project_down(z)
         token_squared_distances = jnp.square(z[..., None] - self.codebook)
         token_bits = jnp.argmin(token_squared_distances, axis=-1)
         return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)
 
-    def decode(self, tokens):
+    def decode(self, tokens: jnp.ndarray) -> jnp.ndarray:
         token_bits = (tokens[..., None] & (2 ** jnp.arange(self.num_dims))).astype(jnp.int32)
         return self.project_up(self.codebook[token_bits])
 
-    def loss(self, x):
+    def loss(self, x: jnp.ndarray) -> LfqCodebookOutput:
         z = self.project_down(x)
         z = self.activation(z)
 
         token_squared_distances = jnp.square(z[..., None] - self.codebook)
         tokens = jnp.argmin(token_squared_distances, axis=-1)
 
-        token_bit_log_probs = -token_squared_distances  # jax.nn.log_softmax(-token_squared_distances, axis=-1)
+        token_bit_log_probs = -token_squared_distances
         # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs
         token_bit_expansions = jnp.bitwise_and(
             jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]
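In `LookupFreeQuantization`, each latent dimension contributes a single bit (nearest of the two codebook entries -1/+1), and `encode` packs those bits into an integer token id. A worked example of that packing with three dimensions:

```python
import jax.numpy as jnp

num_dims = 3
token_bits = jnp.array([1, 0, 1])  # nearest codebook entry per dim: +1, -1, +1
token = jnp.sum(token_bits * (2 ** jnp.arange(num_dims)), axis=-1)
print(token)  # 1*1 + 0*2 + 1*4 = 5
```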
@@ -236,7 +235,7 @@ def loss(self, x):
         )
 
 
-def make_block_causal_attention_matrix(q, k, bs_q, bs_k):
+def make_block_causal_attention_matrix(q: jnp.ndarray, k: jnp.ndarray, bs_q: int, bs_k: int) -> jnp.ndarray:
     return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))
 
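`make_block_causal_attention_matrix` builds a block-causal mask: positions are grouped into blocks, and a query block may attend to its own block and all earlier key blocks. A self-contained illustration with hypothetical sizes (not this function's exact call signature or argument convention):

```python
import jax.numpy as jnp

# 6 query and 6 key positions, grouped into blocks of 2.
q_pos = jnp.arange(6)
k_pos = jnp.arange(6)
block = 2
mask = (q_pos[:, None] // block) >= (k_pos[None, :] // block)
print(mask.astype(jnp.int32))  # lower block-triangular pattern of 2x2 blocks
```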

@@ -245,14 +244,7 @@ class GeGLU(Module):
     GeGLU is a Flax layer that combines a linear transformation with a GELU
     activation function in a gating mechanism. It is often used in Transformer models
     to provide non-linear capabilities while preserving a strong linear component.
-    Example usage::
-        >>> import flax.linen as nn
-        >>> class TransformerBlock(nn.Module):
-        ...   @nn.compact
-        ...   def __call__(self, x):
-        ...     x = nn.Dense(2)(x)
-        ...     x = nn.GeGLU()(x)  # initialized
-        ...     return x
+
     Attributes:
         features: the number of output features (default: None).
     """
@@ -281,7 +273,15 @@ class CrossAttentionLayer(nn.Module):
     mlp_ratio: float = 4.0
 
     @nn.compact
-    def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True):
+    def __call__(
+        self,
+        x: jnp.ndarray,
+        y: jnp.ndarray,
+        *,
+        mask_self: jnp.ndarray | None = None,
+        mask_cross: jnp.ndarray | None = None,
+        train: bool = True,
+    ) -> jnp.ndarray:
         d_embed = x.shape[-1]
         seq_len_q = x.shape[-2]
         seq_len_k = y.shape[-2]
@@ -307,12 +307,10 @@ def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True):
         # Cross-attention block
         skip = x
         x = nn.LayerNorm()(x)
-        # bias = -jnp.abs(jnp.linspace(0, 1, seq_len_q)[:, None] - jnp.linspace(0, 1, seq_len_k)) * 5
         x = nn.MultiHeadDotProductAttention(
             num_heads=self.num_heads or d_embed // 64,
             dropout_rate=self.dropout_rate,
             deterministic=not train,
-            # attention_fn=partial(nn.dot_product_attention, bias=bias),
         )(x, y, y, mask=mask_cross)
         x = skip + x
 
@@ -326,7 +324,7 @@ def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True):
         return skip + x
 
 
-def sinusoidal_pe_init(_, shape):
+def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray:
     seq_len, d_embed = shape
 
     position = jnp.arange(0, seq_len, 1)
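`sinusoidal_pe_init` initializes the query embedding used below with sinusoidal positional encodings. The standard Transformer formulation is sketched here; the exact frequency convention in this file is assumed rather than verified.

```python
import jax.numpy as jnp

# Standard sinusoidal positional encoding; frequency convention assumed.
def sinusoidal_pe(seq_len: int, d_embed: int) -> jnp.ndarray:
    position = jnp.arange(seq_len)[:, None]
    div_term = jnp.exp(jnp.arange(0, d_embed, 2) * (-jnp.log(10000.0) / d_embed))
    pe = jnp.zeros((seq_len, d_embed))
    pe = pe.at[:, 0::2].set(jnp.sin(position * div_term))
    pe = pe.at[:, 1::2].set(jnp.cos(position * div_term))
    return pe

print(sinusoidal_pe(4, 8).shape)  # (4, 8)
```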
@@ -350,7 +348,14 @@ class TokenizerEncoderDecoder(nn.Module):
     use_state_conditioning: bool = False
 
     @nn.compact
-    def __call__(self, y, *, train=True, state_conditioning=None, mask=None):
+    def __call__(
+        self,
+        y: jnp.ndarray,
+        *,
+        train: bool = True,
+        state_conditioning: jnp.ndarray | None = None,
+        mask: jnp.ndarray | None = None,
+    ) -> jnp.ndarray:
         x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
         x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])
 
@@ -392,7 +397,7 @@ class FsqAttentionTokenizer(nn.Module):
     use_state_conditioning: bool = False
 
     @property
-    def vocab_size(self):
+    def vocab_size(self) -> int:
         return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size))  # noqa: SLF001
 
     def setup(self):
@@ -422,7 +427,9 @@ def setup(self):
         self.proj_mean = nn.Dense(self.data_dim)
         self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))
 
-    def tokenize(self, action, *, obs=None, train=False):
+    def tokenize(
+        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = False
+    ) -> tuple[jnp.ndarray, jnp.ndarray]:
         if self.bound is not None:
             action = jnp.clip(action, -self.bound, self.bound)
 
@@ -431,12 +438,14 @@ def tokenize(self, action, *, obs=None, train=False):
 
         return self.codebook.encode(x)
 
-    def detokenize(self, tokens, *, obs=None):
+    def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None = None) -> jnp.ndarray:
         x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
         mean = self.proj_mean(x)
         return mean * self.out_scale
 
-    def loss(self, action, *, obs=None, train=True):
+    def loss(
+        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = True
+    ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
         # Encode
         x = self.proj(action)
         z = self.encoder(x, train=train, state_conditioning=obs)
@@ -456,7 +465,7 @@ def loss(self, action, *, obs=None, train=True):
             "mae": mae,
         }
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
         """
         Dummy for .init
         """
