Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
233 commits
Select commit Hold shift + click to select a range
1303a69
Adapt to Our Datasets (#1)
DDVD233 Aug 11, 2025
a824ed5
Experimental: Add support for audio training
DDVD233 Aug 11, 2025
30e1e06
Debug for audios
DDVD233 Aug 14, 2025
732e70b
Debug for audios
DDVD233 Aug 14, 2025
3892b05
Add omni support
DDVD233 Aug 14, 2025
364713f
Add omni support
DDVD233 Aug 14, 2025
954a52d
Add omni support
DDVD233 Aug 14, 2025
e6ee543
Add omni support
DDVD233 Aug 14, 2025
313846a
Use torchaudio
DDVD233 Aug 14, 2025
fbe5a7f
Debug for audio
DDVD233 Aug 15, 2025
704a666
Debug for audio
DDVD233 Aug 15, 2025
ee0f78f
Debug for audio
DDVD233 Aug 15, 2025
3cb5165
Debug for audio
DDVD233 Aug 15, 2025
902007c
Update prompt
DDVD233 Aug 15, 2025
f976def
debug
DDVD233 Aug 15, 2025
95735c8
debug
DDVD233 Aug 15, 2025
2fdd078
debug
DDVD233 Aug 15, 2025
461e995
debug
DDVD233 Aug 15, 2025
4b6ee75
debug
DDVD233 Aug 15, 2025
551fde3
debug
DDVD233 Aug 15, 2025
c3a0f66
debug
DDVD233 Aug 15, 2025
3bae315
debug
DDVD233 Aug 15, 2025
2552705
debug
DDVD233 Aug 15, 2025
27b0dee
debug
DDVD233 Aug 15, 2025
78b97e3
debug
DDVD233 Aug 15, 2025
ada904d
debug
DDVD233 Aug 15, 2025
b0162af
debug
DDVD233 Aug 15, 2025
efbab94
debug
DDVD233 Aug 15, 2025
6985493
Debug
DDVD233 Aug 15, 2025
0f94e3b
Reduce batch size / remove kl
DDVD233 Aug 15, 2025
092380d
_
keanepotato Aug 15, 2025
af235c0
_
keanepotato Aug 15, 2025
9d4267e
_
keanepotato Aug 15, 2025
f7d8207
_
keanepotato Aug 15, 2025
231440e
_
keanepotato Aug 15, 2025
9b01ddd
_
keanepotato Aug 15, 2025
d548642
_
keanepotato Aug 16, 2025
1e46e85
_
keanepotato Aug 16, 2025
4b8098f
_
keanepotato Aug 16, 2025
e6a780e
_
keanepotato Aug 16, 2025
e737994
_
keanepotato Aug 16, 2025
fd8ab0e
push req txt
keanepotato Aug 16, 2025
64de449
push req txt
keanepotato Aug 16, 2025
ebf10b0
push req txt
keanepotato Aug 16, 2025
144ed8b
push req txt
keanepotato Aug 16, 2025
c2d1c8b
push req txt
keanepotato Aug 16, 2025
ca293ef
_
keanepotato Aug 16, 2025
d1a01c5
_
keanepotato Aug 16, 2025
0fc0184
_
keanepotato Aug 16, 2025
66f0b91
_
keanepotato Aug 16, 2025
ac3ebbf
_
keanepotato Aug 16, 2025
0a7f10a
_
keanepotato Aug 16, 2025
3319c5b
_
keanepotato Aug 16, 2025
192e705
_
keanepotato Aug 16, 2025
4b1f492
_
keanepotato Aug 16, 2025
47fecfb
_
keanepotato Aug 16, 2025
7ac1aba
_
keanepotato Aug 16, 2025
aab4d05
_
keanepotato Aug 16, 2025
16f65b9
_
keanepotato Aug 16, 2025
1d6fb48
_
keanepotato Aug 16, 2025
2e414eb
_
keanepotato Aug 16, 2025
a510d2d
_
keanepotato Aug 16, 2025
7cb9233
_
keanepotato Aug 16, 2025
d5d29cc
_
keanepotato Aug 16, 2025
184ee8c
_
keanepotato Aug 16, 2025
582390d
_
keanepotato Aug 16, 2025
62bab5f
_
keanepotato Aug 16, 2025
47f8f3c
_
keanepotato Aug 17, 2025
88d3d26
update
keanepotato Aug 17, 2025
d9cab8b
_
keanepotato Aug 17, 2025
292893d
_
keanepotato Aug 17, 2025
a5b50e6
_
keanepotato Aug 17, 2025
db61a03
_
keanepotato Aug 17, 2025
87ec13a
_
keanepotato Aug 17, 2025
c3f28f4
_
keanepotato Aug 17, 2025
16caf36
_
keanepotato Aug 17, 2025
741952f
_
keanepotato Aug 17, 2025
8e51860
_
keanepotato Aug 17, 2025
11bab30
_
keanepotato Aug 17, 2025
ad371ed
_
keanepotato Aug 17, 2025
335e4c5
_
keanepotato Aug 17, 2025
b339f9f
_
keanepotato Aug 17, 2025
fe0e2e2
_
keanepotato Aug 17, 2025
761535b
_
keanepotato Aug 17, 2025
513b556
_
keanepotato Aug 17, 2025
c3ac743
_
keanepotato Aug 17, 2025
4d74680
_
keanepotato Aug 17, 2025
81ee854
_
keanepotato Aug 17, 2025
4a133a5
fix audio
keanepotato Aug 17, 2025
5e17e58
attempt fix two
keanepotato Aug 17, 2025
da571b0
_
keanepotato Aug 17, 2025
0a86fe7
_
keanepotato Aug 18, 2025
5f43d91
_
keanepotato Aug 18, 2025
6250fae
_
keanepotato Aug 18, 2025
5a720ad
implement reward func
keanepotato Aug 18, 2025
e279509
_
keanepotato Aug 18, 2025
0b006bc
_
keanepotato Aug 18, 2025
625cf79
_
keanepotato Aug 18, 2025
d9a4d19
_
keanepotato Aug 18, 2025
4a0f56b
_
keanepotato Aug 18, 2025
e626785
_
keanepotato Aug 18, 2025
057a377
_
keanepotato Aug 18, 2025
2c6b5ab
_
keanepotato Aug 18, 2025
33e3eab
_
keanepotato Aug 18, 2025
8f9c2d5
_
keanepotato Aug 18, 2025
e93952a
_
keanepotato Aug 18, 2025
69f2d40
_
keanepotato Aug 18, 2025
c7938d6
_
keanepotato Aug 18, 2025
bcd4939
_
keanepotato Aug 18, 2025
91b44d8
_
keanepotato Aug 18, 2025
fff7e90
_
keanepotato Aug 18, 2025
818754e
_
keanepotato Aug 18, 2025
9d9be83
_
keanepotato Aug 18, 2025
9a95d7e
_
keanepotato Aug 18, 2025
f6c4a54
_
keanepotato Aug 18, 2025
741b4e7
_
keanepotato Aug 18, 2025
80568aa
_
keanepotato Aug 18, 2025
ed04487
_
keanepotato Aug 18, 2025
d417dc2
_
keanepotato Aug 18, 2025
5258d5b
_
keanepotato Aug 18, 2025
1424a55
_
keanepotato Aug 18, 2025
c19287c
update schema
keanepotato Aug 18, 2025
e2ad39c
update schema
keanepotato Aug 18, 2025
c1ff427
_
keanepotato Aug 18, 2025
f597bf8
sync
keanepotato Aug 18, 2025
740a764
_
keanepotato Aug 18, 2025
effc3aa
_
keanepotato Aug 18, 2025
0df474a
_
keanepotato Aug 18, 2025
d892c6c
_
keanepotato Aug 18, 2025
267345d
_
keanepotato Aug 18, 2025
d9f547a
_
keanepotato Aug 18, 2025
822e7b5
push command
keanepotato Aug 18, 2025
6702da7
_
keanepotato Aug 18, 2025
46e3bf8
_
keanepotato Aug 18, 2025
43fce91
remove prints
keanepotato Aug 18, 2025
620640b
implement reward score
keanepotato Aug 18, 2025
cd5b9bd
_
keanepotato Aug 18, 2025
066345c
update hb reward
keanepotato Aug 18, 2025
d2103a6
_
keanepotato Aug 18, 2025
3024ddb
_
keanepotato Aug 18, 2025
d910db3
workers
keanepotato Aug 18, 2025
325029a
_
keanepotato Aug 18, 2025
5b519f1
_
keanepotato Aug 18, 2025
ece6cfe
_
keanepotato Aug 18, 2025
0efaa2a
_
keanepotato Aug 18, 2025
863c4d7
_
keanepotato Aug 18, 2025
3f5626f
_
keanepotato Aug 18, 2025
8e5896b
_
keanepotato Aug 18, 2025
e5acd60
_
keanepotato Aug 18, 2025
d1cd282
change pixels
keanepotato Aug 18, 2025
7586747
change pixels
keanepotato Aug 19, 2025
ed6c4e0
_
keanepotato Aug 19, 2025
e9b985b
try omni pixel length
keanepotato Aug 19, 2025
fbf6ace
_
keanepotato Aug 19, 2025
384e237
_
keanepotato Aug 19, 2025
ea93f8c
_
keanepotato Aug 19, 2025
bb3621d
set omni pixels
keanepotato Aug 19, 2025
cc605fe
_
keanepotato Aug 19, 2025
fb9bdd5
_
keanepotato Aug 19, 2025
17198e1
flash attn fsdp
keanepotato Aug 19, 2025
3416084
_
keanepotato Aug 19, 2025
f378c70
_
keanepotato Aug 19, 2025
ceaf4f8
_
keanepotato Aug 19, 2025
6c50032
set max model len
keanepotato Aug 19, 2025
362ac3d
_
keanepotato Aug 19, 2025
f300071
set rollouts
keanepotato Aug 19, 2025
6185704
_
keanepotato Aug 19, 2025
4bd9a5c
_
keanepotato Aug 19, 2025
aa845fb
_
keanepotato Aug 19, 2025
08e671a
clip audio
keanepotato Aug 19, 2025
dc13cbe
clip audio
keanepotato Aug 19, 2025
ff1cbfc
_
keanepotato Aug 19, 2025
4a216fe
_
keanepotato Aug 19, 2025
faa3fa4
debug
keanepotato Aug 19, 2025
dc9ec7f
debug catch
keanepotato Aug 19, 2025
e29be99
debug catch
keanepotato Aug 19, 2025
cca8867
_
keanepotato Aug 19, 2025
0424575
_
keanepotato Aug 19, 2025
87e55cf
_
keanepotato Aug 19, 2025
cd320d8
push audio
keanepotato Aug 19, 2025
ac97461
_
keanepotato Aug 19, 2025
56883a2
_
keanepotato Aug 19, 2025
fdfdb40
_
keanepotato Aug 19, 2025
6dbcb35
_
keanepotato Aug 19, 2025
98c20e3
_
keanepotato Aug 19, 2025
c539637
_
keanepotato Aug 19, 2025
23955ae
_
keanepotato Aug 19, 2025
a5b41c9
_
keanepotato Aug 19, 2025
237e21e
_
keanepotato Aug 19, 2025
21aba8b
_
keanepotato Aug 19, 2025
02fb620
_
keanepotato Aug 19, 2025
b7468c3
_
keanepotato Aug 19, 2025
7f41420
revert commit
keanepotato Aug 19, 2025
1984a6b
debug print
keanepotato Aug 19, 2025
1efce59
debug print
keanepotato Aug 19, 2025
55aaa26
debug print
keanepotato Aug 19, 2025
6d35542
debug
keanepotato Aug 19, 2025
b02313d
_
keanepotato Aug 19, 2025
10265a9
_
keanepotato Aug 19, 2025
24b6d61
_
keanepotato Aug 19, 2025
d598b57
prep video training
keanepotato Aug 19, 2025
4ec901b
prep audio config
keanepotato Aug 19, 2025
137a73e
_
keanepotato Aug 19, 2025
9c36729
_
keanepotato Aug 20, 2025
f670372
_
keanepotato Aug 20, 2025
7001ac3
_
keanepotato Aug 20, 2025
eaa2f40
_
keanepotato Aug 20, 2025
c2e0c3d
_
keanepotato Aug 20, 2025
55d108e
_
keanepotato Aug 20, 2025
650550e
_
keanepotato Aug 20, 2025
b0e7a9b
debug_off
keanepotato Aug 20, 2025
3a73517
_
keanepotato Aug 20, 2025
7d11247
_
keanepotato Aug 20, 2025
d68f375
_
keanepotato Aug 20, 2025
eca7a4a
_
keanepotato Aug 20, 2025
8ecf257
_
keanepotato Aug 20, 2025
8dfbd2a
_
keanepotato Aug 20, 2025
ee9d04c
_
keanepotato Aug 20, 2025
bbfca47
push modality sampler
keanepotato Aug 20, 2025
68a07b3
push modality sampler
keanepotato Aug 20, 2025
620f56d
add debug
keanepotato Aug 20, 2025
1e10146
_
keanepotato Aug 20, 2025
0d5d23f
add to config
keanepotato Aug 20, 2025
9ceabdc
add features
keanepotato Aug 20, 2025
9e21b46
touch up dataloader
keanepotato Aug 20, 2025
564b966
debug write
keanepotato Aug 20, 2025
3c74a53
debug write
keanepotato Aug 20, 2025
67b9387
debug counterfactual
keanepotato Aug 20, 2025
bd9ca3f
debug counterfact
keanepotato Aug 20, 2025
9465b06
_
keanepotato Aug 20, 2025
772a471
_
keanepotato Aug 20, 2025
6a6f11f
_
keanepotato Aug 20, 2025
682673d
Merge pull request #3 from DDVD233/keane
keanepotato Aug 20, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
308 changes: 308 additions & 0 deletions _unit_test_modality_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
# test_stateful_modality_sampler_hardcoded.py

import json
from typing import Dict, Any, List, Iterator
from torch.utils.data import Dataset, BatchSampler
import random

# ==== ADJUST PATHS below to match your repo structure ====
# from verl.utils.dataset.modality_sampler import ModalitySignatureBatchSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from collections import defaultdict, deque
import torch
import numpy as np

# ---------- HARD-CODED PATHS + CONFIG ----------
# Local JSONL of samples; each line is a JSON object expected to carry a
# 'modality_signature' key (rows without it are skipped by JsonlDataset).
JSONL_PATH = "/Users/keane/Desktop/research/human-behavior/data/all/sigs_no_lmvd_discretized_v3_template_prompts.jsonl"
TRAIN_BS = 4  # train batch size (each batch is homogeneous by signature)
VAL_BS = 4  # validation batch size
SEED = 42  # shared seed for per-signature truncation and sampler shuffling
TRUNCATE_RATIO = 0.001 # for quick testing; set to 1.0 to disable
# ---------------------------------------------

# TODO: Please remove text only; everything should be text_only
# NOTE(review): the TODO above is ambiguous — presumably it asks to normalize a
# "text only" signature label to "text_only"; confirm intent with the author.

class ModalitySignatureBatchSampler(BatchSampler):
    """
    Round-robin batch sampler across modality signatures.

    - Shuffles within each signature if ``shuffle=True`` (train).
    - Each yielded batch is homogeneous by ``modality_signature``.
    - When a signature runs out of batches it is pruned and round-robin
      continues over the remaining signatures.
    """

    def __init__(
        self,
        indices_by_sig: Dict[str, List[int]],
        batch_size: int,
        drop_last: bool = True,
        seed: int = 42,
        shuffle: bool = True,
    ):
        """
        Args:
            indices_by_sig: Mapping of signature -> dataset indices for that signature.
            batch_size: Number of indices per yielded batch.
            drop_last: Drop a trailing partial batch within each signature.
            seed: Seed for the sampler's private RNG (shuffle + epoch rotation).
            shuffle: Shuffle within each signature and rotate the RR start.
        """
        # Defensive copies so later caller mutations cannot corrupt the sampler.
        self.indices_by_sig = {s: list(v) for s, v in indices_by_sig.items()}
        self.batch_size = int(batch_size)
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.rng = random.Random(seed)
        self.sigs = list(self.indices_by_sig.keys())

    def _batches_for(self, pool: List[int]) -> List[List[int]]:
        """Chunk one signature's index pool into batches, honoring drop_last."""
        batches = []
        for start in range(0, len(pool), self.batch_size):
            chunk = pool[start:start + self.batch_size]
            if len(chunk) < self.batch_size and self.drop_last:
                continue
            if chunk:
                batches.append(chunk)
        return batches

    def __iter__(self) -> Iterator[List[int]]:
        """Yield homogeneous batches, round-robin over active signatures."""
        # Fresh pools + optional shuffle within each signature per epoch.
        pools = {s: list(v) for s, v in self.indices_by_sig.items()}
        if self.shuffle:
            for s in pools:
                self.rng.shuffle(pools[s])

        # Per-signature batch queues.
        per_sig_batches = {s: deque(self._batches_for(pools[s])) for s in self.sigs}

        # Establish round-robin order.
        order = list(self.sigs)
        if self.shuffle:
            # Rotate the start signature per epoch for variety (keeps RR structure).
            k = self.rng.randrange(len(order)) if order else 0
            order = order[k:] + order[:k]
        else:
            order = sorted(order)

        # Only signatures that actually produced batches enter the rotation.
        active = deque([s for s in order if len(per_sig_batches[s]) > 0])

        while active:
            s = active.popleft()
            q = per_sig_batches[s]
            yield q.popleft()
            # Re-append only while batches remain; an exhausted signature is
            # pruned simply by not re-appending it. (An invariant of this loop
            # is that every signature in `active` has a non-empty queue, so no
            # empty-queue branch is needed here.)
            if q:
                active.append(s)

    def __len__(self) -> int:
        """Total number of batches across all signatures (after drop_last handling)."""
        total = 0
        for pool in self.indices_by_sig.values():
            full, rem = divmod(len(pool), self.batch_size)
            total += full + (0 if self.drop_last or rem == 0 else 1)
        return total


def rl_collate_fn(data_list: list[dict]) -> dict:
    """
    Collate a batch of sample dicts into batched tensors and arrays.

    Args:
        data_list: List of dicts mapping feature names to torch.Tensor or other values.

    Returns:
        Dict where tensor entries are stacked into a torch.Tensor of shape
        (batch_size, dims) and non-tensor entries are converted to
        np.ndarray of dtype object with shape (batch_size,).
    """
    tensor_lists = defaultdict(list)
    object_lists = defaultdict(list)

    # Route every feature value by type: tensors get stacked, the rest become
    # 1-D object arrays so ragged/non-numeric payloads survive batching.
    for sample in data_list:
        for name, value in sample.items():
            bucket = tensor_lists if isinstance(value, torch.Tensor) else object_lists
            bucket[name].append(value)

    stacked = {name: torch.stack(values, dim=0) for name, values in tensor_lists.items()}
    objects = {
        name: np.fromiter(values, dtype=object, count=len(values))
        for name, values in object_lists.items()
    }
    return {**stacked, **objects}

def create_rl_sampler(data_config, dataset, split: str = "train"):
    """Create a sampler for the dataset, grouping strictly by existing modality_signature."""
    import torch
    from torch.utils.data import RandomSampler, SequentialSampler

    cfg_key = "modality_batching" if split == "train" else "val_modality_batching"
    mb_cfg = data_config.get(cfg_key)

    # (keep curriculum path if you actually use it; omitted here for brevity)

    if mb_cfg and mb_cfg.get("enabled", False):
        # Prefer the raw dataframe when the dataset exposes one; otherwise
        # index the dataset directly.
        source = dataset.dataframe if hasattr(dataset, "dataframe") else dataset
        by_sig: Dict[str, List[int]] = {}
        for idx in range(len(dataset)):
            signature = source[idx].get("modality_signature")
            if signature is None:
                print(f"[WARNING] Row {idx} missing 'modality_signature'. Skipping.")
                continue
            by_sig.setdefault(signature, []).append(idx)

        default_bs_key = "train_batch_size" if split == "train" else "val_batch_size"
        batch_size = mb_cfg.get("batch_size", data_config.get(default_bs_key))

        return ModalitySignatureBatchSampler(
            indices_by_sig=by_sig,
            batch_size=int(batch_size),
            drop_last=mb_cfg.get("drop_last", split == "train"),
            seed=data_config.get("seed", 42),
            shuffle=(split == "train"),
        )

    # Fallbacks: plain random sampling for shuffled training, sequential otherwise.
    if split == "train" and data_config.get("shuffle", True):
        generator = torch.Generator()
        generator.manual_seed(data_config.get("seed", 1))
        return RandomSampler(data_source=dataset, generator=generator)
    return SequentialSampler(data_source=dataset)

class JsonlDataset(Dataset):
    def __init__(self, jsonl_path: str, truncate_ratio: float = TRUNCATE_RATIO, seed: int = SEED):
        """
        Loads ONLY entries that already have 'modality_signature'.
        Optionally keeps a proportion per signature for fast debugging.
        """
        # Group rows by signature while streaming the file; unsigned entries
        # are warned about and dropped.
        sig_to_rows: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                raw_line = raw_line.strip()
                if not raw_line:
                    continue
                record = json.loads(raw_line)
                if record.get("modality_signature") is None:
                    print("[WARNING] Entry missing 'modality_signature'. Skipping.")
                    continue
                sig_to_rows[record["modality_signature"]].append(record)

        # Truncate each signature group independently so small groups survive
        # (at least one row is always kept per signature).
        rng = random.Random(seed)
        kept: List[Dict[str, Any]] = []
        for sig, rows in sig_to_rows.items():
            if truncate_ratio >= 1.0:
                kept.extend(rows)
                continue
            keep_n = max(1, int(len(rows) * truncate_ratio))
            rng.shuffle(rows)
            kept.extend(rows[:keep_n])

        self.rows = kept
        self.dataframe = self  # preserve your API

        # simple stats
        counts = {sig: sum(1 for r in self.rows if r["modality_signature"] == sig) for sig in sig_to_rows}
        print(f"[DEBUG] After truncation (ratio={truncate_ratio}), total {len(self.rows)}. Per-signature: {counts}")

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        return self.rows[index]


def assert_homogeneous(batch_list: List[Dict[str, Any]]):
    """Raise AssertionError unless all samples share exactly one modality_signature."""
    sigs = {sample.get("modality_signature") for sample in batch_list}
    if len(sigs) == 1:
        return
    raise AssertionError(f"Non-homogeneous batch signatures: {sigs}")

def collate_with_guard(samples):
    """Guarded collate: reject mixed-signature batches, then batch normally."""
    assert_homogeneous(samples)
    return rl_collate_fn(samples)

def build_cfg(train_bs: int, val_bs: int, seed: int = 42):
    """Build a dot-accessible config dict mirroring the trainer's data config."""

    class Dot(dict):
        # Attribute access falls through to dict lookup (missing key -> None).
        __getattr__ = dict.get
        __setattr__ = dict.__setitem__
        __delattr__ = dict.__delitem__

    cfg = Dot()
    cfg["train_batch_size"] = train_bs
    cfg["val_batch_size"] = val_bs
    cfg["shuffle"] = True
    cfg["seed"] = seed
    cfg["dataloader_num_workers"] = 0
    cfg["validation_shuffle"] = False
    cfg["sampler"] = None
    cfg["modality_batching"] = {"enabled": True, "batch_size": train_bs, "drop_last": True}
    cfg["val_modality_batching"] = {"enabled": True, "batch_size": val_bs, "drop_last": False}
    return cfg

def build_loader(dataset, data_cfg, split: str):
    """
    Build a StatefulDataLoader for the given split.

    When create_rl_sampler returns a BatchSampler (modality batching enabled),
    it is passed as batch_sampler; otherwise a plain sampler plus an explicit
    batch size is used.
    """
    sampler_or_batch = create_rl_sampler(data_cfg, dataset, split=split)
    if isinstance(sampler_or_batch, BatchSampler):
        # A batch_sampler already yields full index lists, so batch_size /
        # shuffle / drop_last must not be passed alongside it.
        return StatefulDataLoader(
            dataset=dataset,
            batch_sampler=sampler_or_batch,
            num_workers=data_cfg["dataloader_num_workers"],
            collate_fn=collate_with_guard,
        )
    bs = data_cfg.get("train_batch_size" if split == "train" else "val_batch_size")
    return StatefulDataLoader(
        dataset=dataset,
        sampler=sampler_or_batch,
        batch_size=bs,
        num_workers=data_cfg["dataloader_num_workers"],
        drop_last=(split == "train"),
        # A sampler is always supplied, so shuffle must be False for every
        # split; the original `False if split == "val" else False` collapsed
        # to this constant anyway.
        shuffle=False,
        collate_fn=collate_with_guard,
    )

def main():
    """Smoke-test the modality-grouped samplers against the hard-coded JSONL."""
    dataset = JsonlDataset(JSONL_PATH)
    sig_counts = {
        sig: sum(1 for r in dataset.rows if r['modality_signature'] == sig)
        for sig in sorted({r['modality_signature'] for r in dataset.rows})
    }
    print(f"Dataset size: {len(dataset)}; per-signature counts:", sig_counts)

    cfg = build_cfg(TRAIN_BS, VAL_BS, SEED)

    # TRAIN: two fresh epochs must produce the same number of steps.
    train_loader = build_loader(dataset, cfg, split="train")
    print("\n[TRAIN] Iteration 1")
    first_epoch_steps = sum(1 for _ in train_loader)
    print(f"train steps: {first_epoch_steps} (drop_last=True)")

    second_epoch_steps = sum(1 for _ in build_loader(dataset, cfg, split="train"))
    assert first_epoch_steps == second_epoch_steps
    print("[TRAIN] Iteration 2: step count consistent")

    # VAL: partial batches are kept.
    val_loader = build_loader(dataset, cfg, split="val")
    print("\n[VAL] Iteration 1")
    val_steps = sum(1 for _ in val_loader)
    print(f"val steps: {val_steps} (drop_last=False)")

    # Stateful resume check (if supported).
    if hasattr(train_loader, "state_dict"):
        print("\n[STATEFUL] Testing resume mid-epoch")
        partial_loader = build_loader(dataset, cfg, split="train")
        partial_iter = iter(partial_loader)
        next(partial_iter)
        next(partial_iter)  # consume 2
        snapshot = partial_loader.state_dict()
        resumed_loader = build_loader(dataset, cfg, split="train")
        resumed_loader.load_state_dict(snapshot)
        resumed = sum(1 for _ in resumed_loader)
        print(f"resumed batches after 2 consumed: {resumed}")

    print("\nOK: StatefulDataLoader + sampler test finished.")


if __name__ == "__main__":
    main()
Loading