Skip to content

Commit 8734c41

Browse files
Merge pull request #25 from Basile-Terv/main
fix rollout loss term of V-JEPA 2-AC
2 parents 1a1f68d + 93366fe commit 8734c41

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

app/vjepa_droid/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def _step_predictor(_z, _a, _s, _e):
427427
z_tf = _step_predictor(_z, _a, _s, _e)
428428

429429
# -- full auto-regressive rollouts of predictor
430-
_z = torch.cat([z[:, :tokens_per_frame], z_tf[:, tokens_per_frame : 2 * tokens_per_frame]], dim=1)
430+
_z = torch.cat([z[:, : tokens_per_frame], z_tf[:, : tokens_per_frame]], dim=1)
431431
for n in range(1, auto_steps):
432432
_a, _s, _e = actions[:, : n + 1], states[:, : n + 1], extrinsics[:, : n + 1]
433433
_z_nxt = _step_predictor(_z, _a, _s, _e)[:, -tokens_per_frame:]

0 commit comments

Comments
 (0)