Model Modification for Autoregressive Usage #12

Merged
merged 7 commits into from
May 10, 2025

Changes from all commits
34 changes: 34 additions & 0 deletions assets/sample_dataset/test/test_10_metadata.json
@@ -0,0 +1,34 @@
{
"document_id": "1246772",
"id": "1246772_a2387452",
"sentence": {
"id": "a2387452",
"start": 233480,
"end": 236640,
"english": "I held my talk in July.",
"german": "Ich habe im Juli den Vortrag gehalten.",
"glosses": {
"gloss": [
"$INDEX1",
"ICH1*",
"JULI1A",
"VORTRAG1",
"ICH1"
],
"start": [
233540,
233920,
234320,
234880,
235360
],
"end": [
233620,
234040,
234520,
235180,
235580
]
}
}
}
Binary file added assets/sample_dataset/test/test_10_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_10_updated.pose
Binary file not shown.
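Each `*_metadata.json` in the sample dataset follows the same layout as `test_10_metadata.json` above: a `sentence` object with English/German translations and a `glosses` block whose `gloss`, `start`, and `end` entries are parallel arrays (the timestamps look like milliseconds, though that is inferred from the values, not documented in the PR). A minimal sketch of reading one such file; the `load_gloss_spans` helper is hypothetical and not part of this PR:

```python
import json
from pathlib import Path

def load_gloss_spans(metadata_path: Path):
    """Read one sample's metadata and zip its parallel gloss arrays."""
    with open(metadata_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    sentence = meta["sentence"]
    glosses = sentence["glosses"]
    # gloss[i] spans [start[i], end[i]] inside the parent sentence's time range
    spans = list(zip(glosses["gloss"], glosses["start"], glosses["end"]))
    return sentence["german"], sentence["english"], spans

# Hypothetical usage:
# german, english, spans = load_gloss_spans(Path("assets/sample_dataset/test/test_10_metadata.json"))
# spans -> [("$INDEX1", 233540, 233620), ("ICH1*", 233920, 234040), ...]
```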
34 changes: 34 additions & 0 deletions assets/sample_dataset/test/test_11_metadata.json
@@ -0,0 +1,34 @@
{
"document_id": "1413251",
"id": "1413251_a3093382",
"sentence": {
"id": "a3093382",
"start": 593460,
"end": 595260,
"english": "Here, it should also grow slowly, that’s better.",
"german": "Hier sollte das auch langsam wachsen, das ist besser.",
"glosses": {
"gloss": [
"WIE3A*",
"HIER1",
"LANGSAM1*",
"BESSER2",
"$GEST-NM-KOPFNICKEN1^"
],
"start": [
593460,
593660,
594240,
594600,
594900
],
"end": [
593600,
593820,
594480,
594780,
595260
]
}
}
}
Binary file added assets/sample_dataset/test/test_11_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_11_updated.pose
Binary file not shown.
31 changes: 31 additions & 0 deletions assets/sample_dataset/test/test_12_metadata.json
@@ -0,0 +1,31 @@
{
"document_id": "1419797",
"id": "1419797_a2743885",
"sentence": {
"id": "a2743885",
"start": 262360,
"end": 265060,
"english": "I take my hat off to him, he truly is terrific.",
"german": "Ich ziehe meinen Hut vor ihm, er ist richtig toll!",
"glosses": {
"gloss": [
"BRAVO2*",
"HUT-AB1",
"BRAVO2*",
"APPLAUS1"
],
"start": [
262460,
263040,
263960,
264560
],
"end": [
262700,
263420,
264220,
264700
]
}
}
}
Binary file added assets/sample_dataset/test/test_12_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_12_updated.pose
Binary file not shown.
40 changes: 40 additions & 0 deletions assets/sample_dataset/test/test_13_metadata.json
@@ -0,0 +1,40 @@
{
"document_id": "1211515",
"id": "1211515_a3240034",
"sentence": {
"id": "a3240034",
"start": 92860,
"end": 96360,
"english": "The food was better than the stuff we got at school.",
"german": "Da hat man besseres Essen als im Heim bekommen.",
"glosses": {
"gloss": [
"BEKOMMEN3*",
"GUT1*",
"ESSEN1",
"ALS1",
"$INDEX1",
"HEIM1A*",
"$INDEX1"
],
"start": [
92960,
93480,
93660,
94060,
94360,
94640,
94760
],
"end": [
93260,
93520,
93960,
94180,
94460,
94720,
96180
]
}
}
}
Binary file added assets/sample_dataset/test/test_13_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_13_updated.pose
Binary file not shown.
31 changes: 31 additions & 0 deletions assets/sample_dataset/test/test_14_metadata.json
@@ -0,0 +1,31 @@
{
"document_id": "1204691",
"id": "1204691_a2770164",
"sentence": {
"id": "a2770164",
"start": 136840,
"end": 138380,
"english": "I need to tell you something important.",
"german": "Ich muss dir etwas wichtiges sagen.",
"glosses": {
"gloss": [
"ICH1*",
"BESCHEID1A*",
"WICHTIG1*",
"WAS1A"
],
"start": [
137080,
137280,
137620,
138000
],
"end": [
137120,
137440,
137840,
138340
]
}
}
}
Binary file added assets/sample_dataset/test/test_14_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_14_updated.pose
Binary file not shown.
28 changes: 28 additions & 0 deletions assets/sample_dataset/test/test_15_metadata.json
@@ -0,0 +1,28 @@
{
"document_id": "1428472",
"id": "1428472_a3228720",
"sentence": {
"id": "a3228720",
"start": 836740,
"end": 837840,
"english": "Why am I deaf?",
"german": "Warum war ich gehörlos?",
"glosses": {
"gloss": [
"ICH1",
"TAUB-GEHÖRLOS1A",
"WARUM10A*"
],
"start": [
836920,
837300,
837720
],
"end": [
837020,
837580,
837840
]
}
}
}
Binary file added assets/sample_dataset/test/test_15_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_15_updated.pose
Binary file not shown.
28 changes: 28 additions & 0 deletions assets/sample_dataset/test/test_16_metadata.json
@@ -0,0 +1,28 @@
{
"document_id": "1582654",
"id": "1582654_a2659401",
"sentence": {
"id": "a2659401",
"start": 590240,
"end": 592860,
"english": "Yes. It was my first time flying for such a long period of time, nine hours.",
"german": "Ja. Das war das erste Mal, dass ich so weit geflogen bin, neun Stunden.",
"glosses": {
"gloss": [
"ERSTES-MAL3C",
"WEIT-SEHR1",
"STUNDE2A*"
],
"start": [
590480,
591120,
591700
],
"end": [
590920,
591520,
592460
]
}
}
}
Binary file added assets/sample_dataset/test/test_16_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_16_updated.pose
Binary file not shown.
34 changes: 34 additions & 0 deletions assets/sample_dataset/test/test_1_metadata.json
@@ -0,0 +1,34 @@
{
"document_id": "1431642-17445220-17471818",
"id": "1431642-17445220-17471818_a3952290",
"sentence": {
"id": "a3952290",
"start": 60040,
"end": 62560,
"english": "There was this one hearing guy who said, “Gee!",
"german": "Da war auch ein Hörender dabei, der meinte: „Ach, Mensch!",
"glosses": {
"gloss": [
"$NUM-EINER1A:1d",
"HÖREND1B",
"DABEI1B*",
"SAGEN1*",
"$GEST-ABWINKEN1^"
],
"start": [
60320,
60700,
61120,
61680,
62120
],
"end": [
60480,
60920,
61300,
61860,
62560
]
}
}
}
Binary file added assets/sample_dataset/test/test_1_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_1_updated.pose
Binary file not shown.
40 changes: 40 additions & 0 deletions assets/sample_dataset/test/test_2_metadata.json
@@ -0,0 +1,40 @@
{
"document_id": "1250972",
"id": "1250972_a2558205",
"sentence": {
"id": "a2558205",
"start": 122820,
"end": 126340,
"english": "After that another firm supported me in my job hunt.",
"german": "Danach hat mich eine andere Firma bei der Jobsuche unterstützt.",
"glosses": {
"gloss": [
"KOMMEN1^*",
"ANDERS1*",
"FIRMA1A",
"HELFEN1*",
"AUF-PERSON1*",
"ICH2*",
"UNTERSTÜTZEN1A"
],
"start": [
123060,
123480,
124420,
125000,
125520,
125760,
125920
],
"end": [
123360,
123840,
124560,
125380,
125600,
125800,
126240
]
}
}
}
Binary file added assets/sample_dataset/test/test_2_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_2_updated.pose
Binary file not shown.
67 changes: 67 additions & 0 deletions assets/sample_dataset/test/test_3_metadata.json
@@ -0,0 +1,67 @@
{
"document_id": "1177002",
"id": "1177002_a3259474",
"sentence": {
"id": "a3259474",
"start": 405260,
"end": 411520,
"english": "Of course with artificial insemination, one could maybe request specific things.",
"german": "Natürlich kann es sein, dass man sich bei der künstlichen Befruchtung dann auch Dinge wünschen kann.",
"glosses": {
"gloss": [
"$ORAL^",
"KLAR1B",
"WENN1A",
"BEISPIEL1*",
"TYPISCH1*",
"$INDEX1",
"KÜNSTLICH5*",
"BEFRUCHTUNG2*",
"SOLL1*",
"EMPFINDUNG1*",
"BESTIMMEN1*",
"KANN2B*",
"ICH1*",
"WÜNSCHEN1B",
"AUSWAHL1C*",
"AUFZÄHLEN1C*"
],
"start": [
405320,
405900,
406260,
406640,
406860,
407180,
407860,
408460,
409080,
409260,
409600,
409840,
410340,
410660,
411000,
411000
],
"end": [
405840,
406040,
406340,
406740,
407120,
407260,
408060,
408960,
409100,
409440,
409620,
410060,
410420,
410780,
411520,
411520
]
}
}
}
Binary file added assets/sample_dataset/test/test_3_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_3_updated.pose
Binary file not shown.
40 changes: 40 additions & 0 deletions assets/sample_dataset/test/test_4_metadata.json
@@ -0,0 +1,40 @@
{
"document_id": "1585089",
"id": "1585089_a2474933",
"sentence": {
"id": "a2474933",
"start": 54540,
"end": 57800,
"english": "I would think it’s more psychological, right?",
"german": "Ich denke, es war eher psychisch bedingt, oder?",
"glosses": {
"gloss": [
"ICH1",
"DENKEN1A",
"$INDEX1",
"DEPRESSION3^*",
"$GEST-OFF^",
"ODER1",
"$GEST-OFF^"
],
"start": [
54620,
55000,
55100,
55640,
56260,
56680,
57160
],
"end": [
54720,
55040,
55160,
56080,
56500,
56880,
57800
]
}
}
}
Binary file added assets/sample_dataset/test/test_4_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_4_updated.pose
Binary file not shown.
40 changes: 40 additions & 0 deletions assets/sample_dataset/test/test_5_metadata.json
@@ -0,0 +1,40 @@
{
"document_id": "1413925",
"id": "1413925_a2795708",
"sentence": {
"id": "a2795708",
"start": 158440,
"end": 160660,
"english": "I was at home in the morning and saw it.",
"german": "Da war ich vormittags zu Hause und habe es gesehen.",
"glosses": {
"gloss": [
"VORMITTAG1",
"ICH1*",
"ZUHAUSE2*",
"ICH1",
"FERNBEDIENUNG1*",
"ZUSCHAUEN2*",
"$GEST^"
],
"start": [
158660,
159000,
159140,
159600,
159780,
159960,
160400
],
"end": [
158960,
159060,
159380,
159640,
159860,
160120,
160660
]
}
}
}
Binary file added assets/sample_dataset/test/test_5_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_5_updated.pose
Binary file not shown.
34 changes: 34 additions & 0 deletions assets/sample_dataset/test/test_6_metadata.json
@@ -0,0 +1,34 @@
{
"document_id": "1187154",
"id": "1187154_a3077606",
"sentence": {
"id": "a3077606",
"start": 155180,
"end": 159140,
"english": "She looks back at him and goes, “What do you think, am I the baker or what?",
"german": "Sie blickt zurück und antwortet: „Glaubst du, ich bin Bäcker oder was?",
"glosses": {
"gloss": [
"GLAUBEN2A",
"DU1*",
"PERSON1",
"BACKEN3A*",
"WAS1A"
],
"start": [
155840,
155960,
157500,
157720,
157980
],
"end": [
155960,
156080,
157600,
157860,
158180
]
}
}
}
Binary file added assets/sample_dataset/test/test_6_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_6_updated.pose
Binary file not shown.
22 changes: 22 additions & 0 deletions assets/sample_dataset/test/test_7_metadata.json
@@ -0,0 +1,22 @@
{
"document_id": "1418858",
"id": "1418858_a2929745",
"sentence": {
"id": "a2929745",
"start": 66140,
"end": 67540,
"english": "Oh I don’t know.",
"german": "Oh, keine Ahnung.",
"glosses": {
"gloss": [
"$GEST-NM^"
],
"start": [
66140
],
"end": [
67200
]
}
}
}
Binary file added assets/sample_dataset/test/test_7_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_7_updated.pose
Binary file not shown.
40 changes: 40 additions & 0 deletions assets/sample_dataset/test/test_8_metadata.json
@@ -0,0 +1,40 @@
{
"document_id": "1289462",
"id": "1289462_a2395820",
"sentence": {
"id": "a2395820",
"start": 155940,
"end": 158800,
"english": "It's been really bad for half a year now.",
"german": "Seit einem Jahr ist das bei mir wirklich schlimm.",
"glosses": {
"gloss": [
"$GEST-ABWINKEN1^*",
"$GEST^*",
"SCHLIMM4*",
"ICH2*",
"VOR1H",
"$NUM-EINER1A:1d",
"VERGANGENHEIT1^"
],
"start": [
156040,
156540,
156960,
157260,
157560,
157820,
158120
],
"end": [
156240,
156720,
157120,
157380,
157640,
157940,
158140
]
}
}
}
Binary file added assets/sample_dataset/test/test_8_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_8_updated.pose
Binary file not shown.
37 changes: 37 additions & 0 deletions assets/sample_dataset/test/test_9_metadata.json
@@ -0,0 +1,37 @@
{
"document_id": "1291572",
"id": "1291572_a2663101",
"sentence": {
"id": "a2663101",
"start": 545380,
"end": 547520,
"english": "Then they are one of the same flesh and blood.",
"german": "Sie sind dann ein Fleisch und Blut.",
"glosses": {
"gloss": [
"$NUM-EINER1A:1d*",
"$NUM-EINER1A:1d*",
"FLEISCH1B",
"UND5*",
"BLUT1D",
"ZUSAMMENHANG1A^*"
],
"start": [
545520,
546040,
546460,
546680,
546880,
547200
],
"end": [
545580,
546160,
546560,
546740,
547060,
547480
]
}
}
}
Binary file added assets/sample_dataset/test/test_9_original.pose
Binary file not shown.
Binary file added assets/sample_dataset/test/test_9_updated.pose
Binary file not shown.
21 changes: 10 additions & 11 deletions fluent_pose_synthesis/config/default.json
@@ -1,32 +1,31 @@
{
"arch": {
"decoder": "trans_enc",
"latent_dim": 256,
"ff_size": 256,
"latent_dim": 512,
"ff_size": 1024,
"num_heads": 4,
"num_layers": 2,
"clip_len": 40,
"num_layers": 8,
"chunk_len": 40,
"keypoints": 178,
"dims": 3,
"dropout": 0.1,
"dropout": 0.2,
"activation": "gelu",
"ablation": null,
"legacy": false
},
"diff": {
"noise_schedule": "cosine",
"diffusion_steps": 4,
"diffusion_steps": 32,
"sigma_small": true
},
"trainer": {
"epoch": 50,
"epoch": 300,
"lr": 1e-4,
"batch_size": 64,
"cond_mask_prob": 0.15,
"batch_size": 1024,
"cond_mask_prob": 0,
"use_loss_mse": true,
"use_loss_vel": true,
"use_loss_3d": true,
"workers": 4,
"load_num": 64
"load_num": 200
}
}
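The new defaults scale the model up considerably: `latent_dim` 256→512, `ff_size` 256→1024, `num_layers` 2→8, `diffusion_steps` 4→32, and training now runs 300 epochs at batch size 1024 with conditioning masking disabled (`cond_mask_prob: 0`); `clip_len` is renamed to `chunk_len` to match the autoregressive chunking. A small sketch, assuming nothing beyond the JSON itself, of loading these values into attribute-style objects the way the training code reads them (`config.arch.chunk_len`, `config.trainer.lr`, ...); the `load_config` helper is illustrative and not a function in this repository:

```python
import json
from types import SimpleNamespace

def load_config(path: str) -> SimpleNamespace:
    # object_hook turns every JSON object into a SimpleNamespace,
    # so nested keys become nested attributes.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f, object_hook=lambda d: SimpleNamespace(**d))

config = load_config("fluent_pose_synthesis/config/default.json")
print(config.arch.chunk_len, config.diff.diffusion_steps, config.trainer.batch_size)
# -> 40 32 1024
```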
16 changes: 10 additions & 6 deletions fluent_pose_synthesis/config/option.py
@@ -6,7 +6,7 @@
def add_model_args(parser):
parser.add_argument('--decoder', type=str, default='trans_enc', help='Decoder type.')
parser.add_argument('--latent_dim', type=int, default=256, help='Transformer/GRU latent dimension.')
parser.add_argument('--ff_size', type=int, default=1024, help='Feed-forward size.')
parser.add_argument('--ff_size', type=int, default=512, help='Feed-forward size.')
parser.add_argument('--num_heads', type=int, default=4, help='Number of attention heads.')
parser.add_argument('--num_layers', type=int, default=4, help='Number of model layers.')

@@ -16,14 +16,16 @@ def add_diffusion_args(parser):
parser.add_argument('--sigma_small', action='store_true', help='Use small sigma values.')

def add_train_args(parser):
parser.add_argument('--epoch', type=int, default=100, help='Number of training epochs.')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
parser.add_argument('--epoch', type=int, default=300, help='Number of training epochs.')
parser.add_argument('--lr', type=float, default=0.00005, help='Learning rate.')
parser.add_argument('--lr_anneal_steps', type=int, default=0, help='Annealing steps.')
parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay.')
parser.add_argument('--batch_size', type=int, default=512, help='Batch size.')
parser.add_argument('--cond_mask_prob', type=float, default=0.15, help='Conditioning mask probability.')
parser.add_argument('--weight_decay', type=float, default=0.00, help='Weight decay.')
parser.add_argument('--batch_size', type=int, default=1024, help='Batch size.')
parser.add_argument('--cond_mask_prob', type=float, default=0, help='Conditioning mask probability.')
parser.add_argument('--workers', type=int, default=4, help='Data loader workers.')
parser.add_argument('--ema', default=False, type=bool, help='Use Exponential Moving Average (EMA) for model parameters.')
parser.add_argument('--lambda_vel', type=float, default=1.0, help='Weight factor for the velocity loss term.')
parser.add_argument('--load_num', type=int, default=-1, help='Number of samples to load; -1 loads all.')


def config_parse(args):
@@ -51,6 +53,8 @@ def config_parse(args):
config.trainer.cond_mask_prob = args.cond_mask_prob
config.trainer.workers = args.workers
config.trainer.save_freq = int(config.trainer.epoch // 5)
config.trainer.lambda_vel = args.lambda_vel
config.trainer.load_num = args.load_num


# Save directory
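option.py mirrors these defaults on the command line (with a lower default learning rate of 5e-5 and the two new flags `--lambda_vel` and `--load_num`), and `config_parse` copies the parsed values into the nested trainer config. A rough sketch of how the three helpers compose; the import path assumes the package layout shown in this PR, and the explicit argument list passed to `parse_args` is only for illustration:

```python
import argparse
from fluent_pose_synthesis.config.option import (
    add_model_args, add_diffusion_args, add_train_args, config_parse,
)

parser = argparse.ArgumentParser(description="Fluent pose synthesis options")
add_model_args(parser)      # --decoder, --latent_dim, --ff_size, --num_heads, --num_layers
add_diffusion_args(parser)  # --sigma_small and the other diffusion flags
add_train_args(parser)      # --epoch, --lr, --batch_size, --cond_mask_prob, --lambda_vel, --load_num, ...

args = parser.parse_args(["--batch_size", "32", "--lambda_vel", "1.0", "--load_num", "200"])
# config = config_parse(args)  # nested config: config.trainer.lambda_vel, config.trainer.load_num, ...
```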
396 changes: 196 additions & 200 deletions fluent_pose_synthesis/core/models.py

Large diffs are not rendered by default.

316 changes: 181 additions & 135 deletions fluent_pose_synthesis/core/training.py

Large diffs are not rendered by default.

220 changes: 141 additions & 79 deletions fluent_pose_synthesis/data/load_data.py
@@ -1,5 +1,6 @@
import time
import json
import random
from pathlib import Path
from typing import Any, Dict

@@ -8,28 +9,29 @@
from torch.utils.data import Dataset
from pose_format import Pose
from pose_format.torch.masked.collator import zero_pad_collator
from pose_anonymization.data.normalization import normalize_mean_std


class SignLanguagePoseDataset(Dataset):
def __init__(
self,
data_dir: Path,
split: str,
fluent_frames: int,
chunk_len: int,
dtype=np.float32,
limited_num: int = -1,
):
"""
Args:
data_dir (Path): Root directory where the data is saved. Each split should be in its own subdirectory.
split (str): Dataset split name, including "train", "validation", and "test".
fluent_frames (int): Frames numbers from the fluent (target) sequence to use as target.
chunk_len (int): Number of frames to take from the fluent (target) sequence as the target chunk (chunk length).
dtype: Data type for the arrays, default is np.float32.
limited_num (int): Limit the number of samples to load; default -1 loads all samples.
"""
self.data_dir = data_dir
self.split = split
self.fluent_frames = fluent_frames
self.chunk_len = chunk_len
self.dtype = dtype

# Store only file paths for now, load data on-the-fly
@@ -38,18 +40,12 @@ def __init__(
split_dir = self.data_dir / split
fluent_files = sorted(list(split_dir.glob(f"{split}_*_original.pose")))
if limited_num > 0:
fluent_files = fluent_files[
:limited_num
] # Limit the number of samples to load
fluent_files = fluent_files[:limited_num] # Limit the number of samples to load

for fluent_file in fluent_files:
# Construct corresponding disfluent and metadata file paths based on the file name
disfluent_file = fluent_file.with_name(
fluent_file.name.replace("_original.pose", "_updated.pose")
)
metadata_file = fluent_file.with_name(
fluent_file.name.replace("_original.pose", "_metadata.json")
)
disfluent_file = fluent_file.with_name(fluent_file.name.replace("_original.pose", "_updated.pose"))
metadata_file = fluent_file.with_name(fluent_file.name.replace("_original.pose", "_metadata.json"))
self.examples.append(
{
"fluent_path": fluent_file,
@@ -68,9 +64,7 @@ def __init__(
first_pose = Pose.read(f.read())
self.pose_header = first_pose.header
except Exception as e:
print(
f"[WARNING] Failed to read pose_header from {first_fluent_path}: {e}"
)
print(f"[WARNING] Failed to read pose_header from {first_fluent_path}: {e}")
self.pose_header = None
else:
self.pose_header = None
@@ -84,10 +78,11 @@ def __len__(self) -> int:
def __getitem__(self, idx: int) -> Dict[str, Any]:
"""
Retrieves a sample from the dataset. For each sample, load the entire disfluent sequence as condition,
and randomly sample a cip from the fluent sequence of fixed length (fluent_frames) as target.
and randomly sample a clip from the fluent sequence of fixed length (chunk_len) as target.
Args:
idx (int): Index of the sample to retrieve.
"""

sample = self.examples[idx]

# Load pose sequences and metadata from disk
@@ -98,47 +93,89 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
with open(sample["metadata_path"], "r", encoding="utf-8") as f:
metadata = json.load(f)

# Apply in-place normalization
fluent_pose.normalize()
disfluent_pose.normalize()
# print(f"[DEBUG][Before Norm] Fluent raw data mean: {fluent_pose.body.data.mean(axis=(0, 1, 2))} std {fluent_pose.body.data.std(axis=(0, 1, 2))}")
# print(f"[DEBUG][Before Norm] Disfluent raw data mean: {disfluent_pose.body.data.mean(axis=(0, 1, 2))} std {disfluent_pose.body.data.std(axis=(0, 1, 2))}")

fluent_data = fluent_pose.body.data.astype(self.dtype)
# Use the entire disfluent sequence as condition
disfluent_seq = disfluent_pose.body.data.astype(self.dtype)
disfluent_mask = disfluent_pose.body.mask
# Normalize the pose data
fluent_pose = normalize_mean_std(fluent_pose)
disfluent_pose = normalize_mean_std(disfluent_pose)

# print(f"DEBUG][After Norm] Fluent normalized data mean:: {fluent_pose.body.data.mean(axis=(0, 1, 2))} std {fluent_pose.body.data.std(axis=(0, 1, 2))}")
# print(f"[DEBUG][After Norm] Disfluent normalized data mean: {disfluent_pose.body.data.mean(axis=(0, 1, 2))} std {disfluent_pose.body.data.std(axis=(0, 1, 2))}")

fluent_data = np.array(fluent_pose.body.data.astype(self.dtype))
fluent_mask = fluent_pose.body.data.mask
disfluent_data = np.array(disfluent_pose.body.data.astype(self.dtype))

fluent_length = len(fluent_data)
# Dynamic windowing: randomly select a window of length fluent_frames
if fluent_length > self.fluent_frames:
valid_windows = [
start
for start in range(0, fluent_length - self.fluent_frames + 1)
if np.any(fluent_data[start : start + self.fluent_frames] != 0)
]
start = np.random.choice(valid_windows) if valid_windows else 0
fluent_clip = fluent_data[start : start + self.fluent_frames]

# 1. Randomly sample the start index for the fluent (target) chunk
if fluent_length <= self.chunk_len:
start_idx = 0
target_len = fluent_length
history_len = 0
else:
fluent_clip = fluent_data # Will be padded later using collator
start_idx = random.randint(0, fluent_length - self.chunk_len)
target_len = self.chunk_len
history_len = start_idx

# 2. Extract target chunk (y_k) and history chunk (y_1, ..., y_{k-1})
target_chunk = fluent_data[start_idx : start_idx + target_len]
target_mask = fluent_mask[start_idx : start_idx + target_len]

if history_len > 0:
history_chunk = fluent_data[:history_len]
else:
# MODIFICATION: Force minimum length of 1 for previous_output if empty
history_chunk = np.zeros((1,) + fluent_data.shape[1:], dtype=self.dtype) # create a single empty frame
# The purpose of this is to ensure the current collate_fn works
# else:
# # No history chunk available, create an empty array with time dimension 0
# history_chunk = np.empty((0,) + fluent_data.shape[1:], dtype=self.dtype)

# 3. Prepare the entire disfluent sequence as condition
disfluent_seq = disfluent_data

# 4. Pad target chunk if its actual length is less than chunk_len
if target_chunk.shape[0] < self.chunk_len:
pad_len = self.chunk_len - target_chunk.shape[0]
# Padding 0s for target chunk
padding_shape_data = (pad_len,) + target_chunk.shape[1:]
target_padding = np.zeros(padding_shape_data, dtype=self.dtype)
target_chunk = np.concatenate([target_chunk, target_padding], axis=0)
# Padding for mask (True for masked)
mask_padding = np.ones((pad_len,) + target_mask.shape[1:], dtype=bool)
target_mask = np.concatenate([target_mask, mask_padding], axis=0)

# 5. Convert numpy arrays to torch tensors
target_chunk = torch.from_numpy(target_chunk.astype(np.float32))
history_chunk = torch.from_numpy(history_chunk.astype(np.float32))
disfluent_seq = torch.from_numpy(disfluent_seq.astype(np.float32))
target_mask = torch.from_numpy(target_mask) # Bool tensor

# Frame-level mask generation
target_mask = np.any(fluent_clip != 0, axis=(1, 2, 3)) # shape: [T]
# 6. Squeeze person dimension
target_chunk = target_chunk.squeeze(1) # (T_chunk, K, D)
history_chunk = history_chunk.squeeze(1) # (T_hist, K, D)
disfluent_seq = disfluent_seq.squeeze(1) # (T_disfl, K, D)
target_mask = target_mask.squeeze(1) # (T_chunk, K, D)

# 7. Create conditions dictionary
# Later, zero_pad_collator will handle padding T_disfl and T_hist across the batch
conditions = {
"input_sequence": disfluent_seq, # (T_disfl, K, D)
"previous_output": history_chunk, # (T_hist, K, D)
"target_mask": target_mask # (T_chunk, K, D)
}

# print(f"DEBUG Dataset idx {idx}:")
# print(f" target_chunk shape: {target_chunk.shape}")
# print(f" input_sequence shape: {disfluent_seq.shape}")
# print(f" previous_output shape: {history_chunk.shape}")
# print(f" target_mask shape: {target_mask.shape}")

return {
"data": torch.tensor(
fluent_clip, dtype=torch.float32
), # Fluent target clip
"conditions": {
"input_sequence": torch.tensor(
disfluent_seq, dtype=torch.float32
), # Full disfluent input
"input_mask": torch.tensor(
disfluent_mask, dtype=torch.bool
), # Disfluent sequence mask
"target_mask": torch.tensor(
target_mask, dtype=torch.bool
), # Per-frame valid mask
"metadata": metadata,
},
"data": target_chunk, # (T_chunk, K, D)
"conditions": conditions,
}
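This chunking is the heart of the autoregressive reformulation: every sample is a fixed-length target window of `chunk_len` frames plus all preceding fluent frames as `previous_output`, with the full disfluent sequence as `input_sequence`. A self-contained sketch of the same windowing on dummy data (shapes only; the person dimension is already squeezed out here, and the real mask comes from the pose's masked array rather than the simplified zeros below):

```python
import numpy as np

rng = np.random.default_rng(0)
chunk_len = 40
fluent = rng.standard_normal((110, 178, 3)).astype(np.float32)  # dummy fluent sequence (T, K, D)

T = len(fluent)
start = int(rng.integers(0, T - chunk_len + 1)) if T > chunk_len else 0
target = fluent[start:start + chunk_len]                             # target chunk y_k
history = fluent[:start] if start > 0 else np.zeros((1, 178, 3), dtype=np.float32)  # y_1..y_{k-1}, or one placeholder frame

# Sequences shorter than chunk_len get zero-padded targets, with padded frames marked True in the mask.
pad_len = chunk_len - target.shape[0]
mask = np.concatenate(
    [np.zeros(target.shape, dtype=bool), np.ones((pad_len,) + target.shape[1:], dtype=bool)], axis=0
)
print(target.shape, history.shape, mask.shape)  # (40, 178, 3), (start, 178, 3), (40, 178, 3)
```

With T=110 every window fits, so `pad_len` is 0; the padding branch matters whenever a fluent clip is shorter than 40 frames.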


@@ -148,9 +185,9 @@ def example_dataset():
"""
# Create an instance of the dataset
dataset = SignLanguagePoseDataset(
data_dir=Path("/scratch/ronli/output"),
data_dir=Path("/scratch/ronli/fluent-pose-synthesis/pose_data/output"),
split="train",
fluent_frames=50,
chunk_len=40,
limited_num=128,
)

@@ -161,37 +198,62 @@ def example_dataset():
shuffle=True,
num_workers=0,
drop_last=False,
pin_memory=True,
pin_memory=False,
collate_fn=zero_pad_collator,
)

# Flag to indicate whether to display batch information
display_batch_info = True
# Flag to indicate whether to measure data loading time
measure_loading_time = True

if display_batch_info:
# Display shapes of a batch for debugging purposes
batch = next(iter(dataloader))
print("Batch size:", len(batch))
print("Normalized target clip:", batch["data"].shape)
print("Input sequence:", batch["conditions"]["input_sequence"].shape)
print("Input mask:", batch["conditions"]["input_mask"].shape)
print("Target mask:", batch["conditions"]["target_mask"].shape)

if measure_loading_time:
loading_times = []
start_time = time.time()
for batch in dataloader:
end_time = time.time()
batch_loading_time = end_time - start_time
print(f"Data loading time for each iteration: {batch_loading_time:.4f}s")
loading_times.append(batch_loading_time)
start_time = end_time
avg_loading_time = sum(loading_times) / len(loading_times)
print(f"Average data loading time: {avg_loading_time:.4f}s")
print(f"Total data loading time: {sum(loading_times):.4f}s")
print(f"\n--- Example Batch Info (Batch Size: {dataloader.batch_size}) ---")

batch = next(iter(dataloader))
print("Batch Keys:", batch.keys())
print("Conditions Keys:", batch['conditions'].keys())

print("\nShapes:")
print(f" data (Target Chunk): {batch['data'].shape}")
print(f" conditions['input_sequence'] (Disfluent): {batch['conditions']['input_sequence'].shape} ")
print(f" conditions['previous_output'] (History): {batch['conditions']['previous_output'].shape} ")
print(f" conditions['target_mask']: {batch['conditions']['target_mask'].shape}")

print("\nNormalization Stats (Shapes):")
print(f" Fluent Mean: {dataset.fluent_mean.shape}")
print(f" Fluent Std: {dataset.fluent_std.shape}")
print(f" Disfluent Mean: {dataset.disfluent_mean.shape}")
print(f" Disfluent Std: {dataset.disfluent_std.shape}")

print("\nSample Values (first element of first sequence):")
print(f" Target Chunk (first 5 flattened): {batch['data'][0].flatten()[:5]}")
# Check if history chunk is not empty
if batch["conditions"]["previous_output"].shape[1] > 0:
print(f" History Chunk (first 5 flattened): {batch['conditions']['previous_output'][0].flatten()[:5]}")
else:
print(" History Chunk: Empty")
print(f" Disfluent Seq (first 5 flattened): {batch['conditions']['input_sequence'][0].flatten()[:5]}")
print(f" Target Mask (first 5 flattened): {batch['conditions']['target_mask'][0].flatten()[:5]}")


# if __name__ == '__main__':
# example_dataset()


# Example Output:
# Dataset initialized with 128 samples. Split: train
# Batch Keys: dict_keys(['data', 'conditions'])
# Conditions Keys: dict_keys(['input_sequence', 'previous_output', 'target_mask'])

# Shapes:
# data (Target Chunk): torch.Size([32, 40, 178, 3])
# conditions['input_sequence'] (Disfluent): torch.Size([32, 359, 178, 3])
# conditions['previous_output'] (History): torch.Size([32, 110, 178, 3])
# conditions['target_mask']: torch.Size([32, 40, 178, 3])

# Normalization Stats (Shapes):
# Fluent Mean: torch.Size([1, 178, 3])
# Fluent Std: torch.Size([1, 178, 3])
# Disfluent Mean: torch.Size([1, 178, 3])
# Disfluent Std: torch.Size([1, 178, 3])

# Sample Values (first element of first sequence):
# Target Chunk (first 5 flattened): tensor([ 9.0694e-02, 7.7781e-01, -7.0343e+02, 1.9091e-01, -8.7535e-01])
# History Chunk (first 5 flattened): tensor([0., 0., 0., 0., 0.])
# Disfluent Seq (first 5 flattened): tensor([ 0.1327, 1.0505, -1.7174, 0.2764, -0.7866])
# Target Mask (first 5 flattened): tensor([False, False, False, False, False])
315 changes: 315 additions & 0 deletions fluent_pose_synthesis/infer_autoregressive.py

Large diffs are not rendered by default.
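The new infer_autoregressive.py (315 lines) is not rendered in this view. Judging from the PR title and the conditions dictionary the dataset now emits (`input_sequence` for the disfluent clip, `previous_output` for already-generated frames), the generation loop presumably proceeds chunk by chunk, feeding each sampled chunk back in as history. The sketch below illustrates only that loop structure; it is not the actual script, and `generate_chunk` is a stub standing in for the diffusion sampler:

```python
import torch

def generate_chunk(disfluent_seq: torch.Tensor, history: torch.Tensor, chunk_len: int = 40) -> torch.Tensor:
    # Stand-in for one sampling pass of the trained diffusion model,
    # conditioned on the disfluent sequence and the history so far.
    keypoints, dims = disfluent_seq.shape[-2:]
    return torch.zeros(chunk_len, keypoints, dims)

def autoregressive_generate(disfluent_seq: torch.Tensor, total_len: int, chunk_len: int = 40) -> torch.Tensor:
    """Generate a fluent sequence chunk by chunk, mirroring the dataset's
    'previous_output' conditioning (a single zero frame when no history exists)."""
    history = torch.zeros(1, *disfluent_seq.shape[1:])
    chunks = []
    while sum(c.shape[0] for c in chunks) < total_len:
        chunk = generate_chunk(disfluent_seq, history, chunk_len)
        chunks.append(chunk)
        history = torch.cat([history, chunk], dim=0)
    return torch.cat(chunks, dim=0)[:total_len]

# Illustrative call:
# fluent = autoregressive_generate(torch.zeros(359, 178, 3), total_len=120)  # -> (120, 178, 3)
```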

2 changes: 1 addition & 1 deletion fluent_pose_synthesis/tests/test_models.py
@@ -29,7 +29,7 @@ def test_base_output_shape(arch, seq_len, batch_size):

model = SignLanguagePoseDiffusion(
input_feats=534,
clip_len=seq_len,
chunk_len=seq_len,
keypoints=178,
dims=3,
latent_dim=256,
8 changes: 4 additions & 4 deletions fluent_pose_synthesis/tests/test_overfit.py
@@ -93,7 +93,7 @@ def create_minimal_config(device="cpu"):
arch=Namespace(
keypoints=178,
dims=3,
clip_len=40,
chunk_len=40,
latent_dim=32,
ff_size=64,
num_layers=2,
@@ -119,7 +119,7 @@ def test_overfit_toy_batch():

batch = get_toy_batch(
batch_size=config.trainer.batch_size,
seq_len=config.arch.clip_len,
seq_len=config.arch.chunk_len,
keypoints=config.arch.keypoints,
)

@@ -138,7 +138,7 @@ def test_overfit_toy_batch():

model = SignLanguagePoseDiffusion(
input_feats=config.arch.keypoints * config.arch.dims,
clip_len=config.arch.clip_len,
chunk_len=config.arch.chunk_len,
keypoints=config.arch.keypoints,
dims=config.arch.dims,
latent_dim=config.arch.latent_dim,
@@ -205,7 +205,7 @@ def test_overfit_toy_batch():
print(f"out1 shape: {out1.shape}, out2 shape: {out2.shape}")
expected_shape = (
1,
config.arch.clip_len,
config.arch.chunk_len,
config.arch.keypoints,
config.arch.dims,
)
20 changes: 11 additions & 9 deletions fluent_pose_synthesis/train.py
@@ -59,7 +59,7 @@ def train(
train_dataset = SignLanguagePoseDataset(
data_dir=config.data,
split="train",
fluent_frames=config.arch.clip_len,
chunk_len=config.arch.chunk_len,
dtype=np_dtype,
limited_num=config.trainer.load_num,
)
@@ -76,15 +76,15 @@ def train(

logger.info(
f"Training Dataset includes {len(train_dataset)} samples, "
f"with {config.arch.clip_len} fluent frames per sample."
f"with {config.arch.chunk_len} fluent frames per sample."
)

diffusion = create_gaussian_diffusion(config)
input_feats = config.arch.keypoints * config.arch.dims

model = SignLanguagePoseDiffusion(
input_feats=input_feats,
clip_len=config.arch.clip_len,
chunk_len=config.arch.chunk_len,
keypoints=config.arch.keypoints,
dims=config.arch.dims,
latent_dim=config.arch.latent_dim,
@@ -100,6 +100,7 @@ def train(
device=config.device,
).to(config.device)

logger.info(f"Model: {model}")
trainer = PoseTrainingPortal(
config, model, diffusion, train_dataloader, logger, tb_writer
)
@@ -133,7 +134,8 @@ def main():
parser.add_argument(
"-i",
"--data",
default="/scratch/ronli/output",
default="assets/sample_dataset",
# default="/pose_data/output",
type=str,
help="Path to dataset folder",
)
@@ -143,7 +145,7 @@ def main():
parser.add_argument(
"-s",
"--save",
default="./save",
default="save/debug_run",
type=str,
help="Directory to save model and logs",
)
@@ -162,15 +164,15 @@ def main():
config.save = Path(config.save)

if args.cluster:
config.data = Path("/scratch/ronli/output") / args.data
config.data = Path("/scratch/ronli/pose_data/output") / args.data
config.save = Path("/scratch/ronli/save") / args.name

# Debug mode settings
if "debug" in args.name:
config.trainer.workers = 1
config.trainer.load_num = 16
config.trainer.batch_size = 16
config.trainer.epoch = 100
config.trainer.load_num = -1
config.trainer.batch_size = 32
config.trainer.epoch = 2000

# Handle existing folder
if (