Commit 92ca72f (1 parent: 3fa2de8)

Add support for DROID and LeRobot datasets with training scripts and requirements

- Created requirements-droid.txt for DROID dataset dependencies.
- Created requirements-lerobot.txt for LeRobot dataset dependencies.
- Added train_droid_example.sh for training on the DROID dataset.
- Added train_pusht.sh for training on the LeRobot PushT dataset.
- Implemented train_action.sh as a universal training script for both datasets.
- Developed train_action_universal.py to handle action-conditioned training across datasets.

17 files changed: +2607 −33

INSTALL.md

Lines changed: 328 additions & 1 deletion
@@ -101,6 +101,117 @@ If you encounter issues with video loading, ensure decord2 is properly installed
pip install --upgrade decord2
```

## Dataset Setup

### Downloading Something-Something-v2 Dataset

The Something-Something-v2 (SSv2) dataset is required for training the video classifier. Follow these steps to download it:

#### 1. Register and Request Access

1. Visit the [20BN Something-Something Dataset page](https://developer.qualcomm.com/software/ai-datasets/something-something)
2. Create an account or sign in with Qualcomm Developer credentials
3. Accept the terms and conditions
4. Request access to the Something-Something-v2 dataset

**Note**: The dataset is hosted by Qualcomm and requires registration. Access is typically granted within 24-48 hours.

#### 2. Download the Dataset

Once approved, you'll receive download links. The dataset consists of:

- **Videos**: ~220GB compressed, ~500GB uncompressed
  - 168,913 training videos
  - 24,777 validation videos
  - 27,157 test videos (labels not publicly available)
- **Labels**: JSON files with annotations
  - `train.json`: Training annotations
  - `validation.json`: Validation annotations
  - `labels.json`: Class label mappings (174 action classes)

**Download structure**:

```
20bn-something-something-v2/
├── 20bn-something-something-v2-?? (video archives)
└── labels/ (annotation files)
```

#### 3. Extract Videos

After downloading, extract the video archives:

```bash
# Create videos directory
mkdir -p videos/20bn-something-something-v2

# Extract all parts (this may take a while)
cd videos
cat 20bn-something-something-v2-?? | tar -xzvf -

# Verify extraction
ls 20bn-something-something-v2/ | wc -l  # Should show ~220,847 .webm files
```

#### 4. Organize Labels

Create a labels directory with the annotation files:

```bash
mkdir -p videos/labels
# Move or copy the JSON files
mv train.json validation.json labels.json videos/labels/
```

#### 5. Expected Directory Structure

After setup, your directory should look like:

```
videos/
├── 20bn-something-something-v2/
│   ├── 1.webm
│   ├── 2.webm
│   ├── ...
│   └── 220847.webm
└── labels/
    ├── train.json
    ├── validation.json
    └── labels.json
```

#### 6. Verify Dataset

Run a quick verification to ensure the dataset is properly set up:

```bash
# Check video count
find videos/20bn-something-something-v2 -name "*.webm" | wc -l

# Check label files
for f in videos/labels/{train,validation,labels}.json; do
  echo "Checking $f..."
  python -c "import json; data=json.load(open('$f')); print(f'  Entries: {len(data)}')"
done
```

Expected output:

- Videos: ~220,847 files
- train.json: ~168,913 entries
- validation.json: ~24,777 entries
- labels.json: 174 entries

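Beyond entry counts, it can be worth cross-checking that every annotation actually has a video file on disk. The sketch below assumes each annotation entry carries an `"id"` field matching the video filename stem (e.g. id `"1234"` → `1234.webm`); the helper name is hypothetical, not part of the repository.

```python
import json
from pathlib import Path

def missing_videos(labels_path, videos_dir):
    """Return annotation ids that have no matching .webm file.

    Assumes each entry in the annotation JSON has an "id" field whose
    value matches the video filename stem (id "1234" -> 1234.webm).
    """
    entries = json.loads(Path(labels_path).read_text())
    have = {p.stem for p in Path(videos_dir).glob("*.webm")}
    return [e["id"] for e in entries if e["id"] not in have]

# Example, using the layout above:
# print(len(missing_videos("videos/labels/train.json",
#                          "videos/20bn-something-something-v2")))
```

A non-empty result usually points at an interrupted extraction of one of the archive parts.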
#### Alternative: Using Subset for Testing

For quick testing without downloading the full dataset:

```bash
# Use the --subset-size flag when training
python train_ssv2_classifier.py \
  --videos-dir videos/20bn-something-something-v2 \
  --labels-dir videos/labels \
  --subset-size 1000  # Use only 1000 samples
```

### Memory Issues

If you encounter out-of-memory errors during training:

@@ -171,7 +282,223 @@ To remove the package:
pip uninstall vjepa2-mlx
```

## Fine-tuning the Model

### Overview

The V-JEPA 2 MLX training pipeline uses a **frozen encoder** approach where:

- The pretrained V-JEPA 2 encoder remains **frozen** (no gradient updates)
- Only the **attentive classifier head** is trained
- This approach is efficient, fast, and requires less memory

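The frozen-encoder idea boils down to partitioning the model's parameters by name and only handing the trainable group to the optimizer. This framework-agnostic sketch illustrates the partition; the parameter names and the `encoder.` prefix are hypothetical, not taken from the package's actual module layout.

```python
def split_params(named_params, frozen_prefixes=("encoder.",)):
    """Partition named parameters into frozen and trainable groups.

    Illustrative only: in the pipeline above, optimizer updates are
    computed solely for the classifier head, never for encoder weights.
    """
    frozen, trainable = {}, {}
    for name, param in named_params.items():
        target = frozen if name.startswith(frozen_prefixes) else trainable
        target[name] = param
    return frozen, trainable
```

Because gradients are never computed for the frozen group, both the backward pass and the optimizer state stay small, which is where the speed and memory savings come from.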
### Quick Start Fine-tuning

#### 1. Using the Training Script

```bash
# Basic training command
python train_ssv2_classifier.py \
  --videos-dir videos/20bn-something-something-v2 \
  --labels-dir videos/labels \
  --pretrained-weights weights/vitl_mlx.safetensors \
  --output-dir output_ssv2_classifier \
  --batch-size 4 \
  --num-epochs 10 \
  --use-wandb
```

#### 2. Using the Configuration File

Edit `configs/train/ssv2_classifier_default.yaml` and run:

```bash
python train_ssv2_classifier.py --config configs/train/ssv2_classifier_default.yaml
```

#### 3. Using the Shell Script

```bash
# Make executable (first time only)
chmod +x scripts/train_ssv2.sh

# Run with defaults
./scripts/train_ssv2.sh

# Run with custom settings
BATCH_SIZE=8 NUM_EPOCHS=20 USE_WANDB=true ./scripts/train_ssv2.sh
```

### Fine-tuning Configuration

#### Key Hyperparameters

Edit `configs/train/ssv2_classifier_default.yaml` to customize:

```yaml
training:
  batch_size: 4                   # Adjust based on memory (2-8)
  num_epochs: 10                  # 10-30 for production
  learning_rate: 0.001            # 1e-3 to 5e-4 typical range
  weight_decay: 0.05              # AdamW regularization
  warmup_epochs: 1                # LR warmup period
  gradient_accumulation_steps: 1  # For larger effective batch
  save_every_steps: 1000          # Checkpoint frequency
```

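To make `warmup_epochs` and `learning_rate` concrete: a common realization of warmup is a linear ramp to the base learning rate followed by cosine decay. The schedule actually implemented in `train_ssv2_classifier.py` is not shown in this guide, so treat this as an assumed, illustrative version.

```python
import math

def lr_at(step, total_steps, warmup_steps, base_lr=1e-3):
    """Linear warmup then cosine decay (an assumed schedule, not
    necessarily the one the training script uses)."""
    if step < warmup_steps:
        # ramp linearly from base_lr/warmup_steps up to base_lr
        return base_lr * (step + 1) / warmup_steps
    # cosine decay from base_lr down to 0 over the remaining steps
    t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1 + math.cos(math.pi * t))
```

With `warmup_epochs: 1`, `warmup_steps` would be one epoch's worth of optimizer updates.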
#### Memory Optimization

For limited memory (16GB RAM):

```bash
python train_ssv2_classifier.py \
  --batch-size 2 \
  --gradient-accumulation-steps 4  # Effective batch size = 8
```

For more memory (32GB+ RAM):

```bash
python train_ssv2_classifier.py \
  --batch-size 8 \
  --gradient-accumulation-steps 1
```

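The arithmetic behind "Effective batch size = 8" is worth spelling out: gradients from several small micro-batches are accumulated before each optimizer update, so the update behaves as if it saw one larger batch.

```python
def effective_batch(batch_size, grad_accum_steps):
    """Samples contributing to each optimizer update."""
    return batch_size * grad_accum_steps

def updates_per_epoch(num_samples, batch_size, grad_accum_steps):
    """Optimizer steps per epoch (not forward passes)."""
    return num_samples // (batch_size * grad_accum_steps)
```

So `--batch-size 2 --gradient-accumulation-steps 4` and `--batch-size 8` give the same effective batch of 8, trading memory for extra forward/backward passes.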
#### Quick Testing with Subset

Test your setup with a small data subset:

```bash
python train_ssv2_classifier.py \
  --videos-dir videos/20bn-something-something-v2 \
  --labels-dir videos/labels \
  --pretrained-weights weights/vitl_mlx.safetensors \
  --subset-size 1000 \
  --num-epochs 2 \
  --verbose
```

### Advanced Fine-tuning

#### Resume from Checkpoint

```bash
python train_ssv2_classifier.py \
  --resume-from output_ssv2_classifier/classifier_step_5000.safetensors \
  --output-dir output_ssv2_classifier
```

Or in config:

```yaml
output:
  resume_from: "output_ssv2_classifier/classifier_step_5000.safetensors"
```

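Since checkpoints encode their step number in the filename (`classifier_step_N.safetensors`), picking the most recent one can be automated. This small helper is not part of the package, just a convenience sketch built on the naming convention shown in this guide.

```python
import re
from pathlib import Path

def latest_checkpoint(output_dir):
    """Return the classifier_step_N.safetensors path with the highest N,
    or None if the directory holds no step checkpoints."""
    best, best_step = None, -1
    for p in Path(output_dir).glob("classifier_step_*.safetensors"):
        m = re.fullmatch(r"classifier_step_(\d+)\.safetensors", p.name)
        if m and int(m.group(1)) > best_step:
            best, best_step = p, int(m.group(1))
    return best
```

The returned path can be fed straight to `--resume-from`.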
#### Weights & Biases Integration

Enable experiment tracking:

```bash
python train_ssv2_classifier.py \
  --use-wandb \
  --wandb-project "my-ssv2-experiments" \
  --wandb-entity "my-team" \
  --wandb-run-name "vitl-bs8-lr1e3"
```

Or in config:

```yaml
wandb:
  enabled: true
  project: "my-ssv2-experiments"
  entity: "my-team"
  run_name: "vitl-bs8-lr1e3"
```

#### Adjusting Model Architecture

Customize classifier architecture in config:

```yaml
model:
  num_probe_blocks: 1    # Classifier depth (1-3)
  num_heads: 16          # Attention heads (8, 12, 16)
  frames_per_clip: 16    # Temporal resolution (8, 16, 32)
  resolution: 224        # Spatial resolution (224, 256)
```

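`frames_per_clip` and `resolution` drive memory use mainly through the encoder's token count. Assuming a ViT-style patchification with patch size 16 and temporal tubelet size 2 (common V-JEPA defaults, but an assumption here, not something this guide specifies):

```python
def num_tokens(frames_per_clip, resolution, patch=16, tubelet=2):
    """Token count for a ViT-style video encoder; patch/tubelet sizes
    are assumed defaults. Attention cost scales with the square of this."""
    return (frames_per_clip // tubelet) * (resolution // patch) ** 2
```

Under these assumptions, the default 16 frames at 224px gives 1,568 tokens, while dropping to 8 frames halves that, which is why reducing `frames_per_clip` is an effective memory lever.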
### Expected Training Performance

On Apple Silicon (M2 Max, 64GB):

- **Training speed**: ~2-3 samples/sec with batch size 4
- **Memory usage**: ~8-12GB during training
- **Full epoch time**: ~18-24 hours for the full SSv2 dataset
- **Validation accuracy**: ~40-50% top-1 after 10 epochs

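The epoch-time figure follows directly from the throughput and the training-split size, as a quick back-of-envelope check:

```python
# Back-of-envelope epoch time from the throughput quoted above
train_samples = 168_913        # SSv2 training split size
samples_per_sec = 2.5          # midpoint of the 2-3 samples/sec figure
hours = train_samples / samples_per_sec / 3600
print(f"~{hours:.1f} hours per epoch")  # lands inside the quoted 18-24h range
```

At 3 samples/sec the same arithmetic gives roughly 15.6 hours, so the quoted range also absorbs some data-loading overhead.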
### Output Files

Training produces:

```
output_ssv2_classifier/
├── training_YYYYMMDD_HHMMSS.log       # Training log
├── best_classifier.safetensors        # Best model (highest val acc)
├── classifier_step_1000.safetensors   # Periodic checkpoints
├── classifier_step_2000.safetensors
└── training_history.json              # Metrics history
```

### Monitoring Training

#### View Real-time Logs

```bash
tail -f output_ssv2_classifier/training_*.log
```

#### Check Training History

```python
import json

with open('output_ssv2_classifier/training_history.json') as f:
    history = json.load(f)
print(f"Best validation accuracy: {max(history['val_acc']):.2%}")
```

#### Weights & Biases Dashboard

If using W&B, view metrics at: `https://wandb.ai/{entity}/{project}`

### Troubleshooting Fine-tuning

#### Low Validation Accuracy

- Increase `num_epochs` (try 20-30)
- Adjust `learning_rate` (try 5e-4 or 2e-3)
- Increase `batch_size` or gradient accumulation
- Verify dataset integrity

#### Out of Memory

- Reduce `batch_size`
- Enable gradient accumulation
- Reduce `frames_per_clip` (try 8 or 12)
- Lower `resolution` (try 192 or 200)

#### Slow Training

- Increase `batch_size` if memory allows
- Reduce `frames_per_clip` for faster video loading
- Use `--subset-size` for initial experiments
- Close other memory-intensive applications

#### Training Crashes

- Check that dataset paths are correct
- Verify the pretrained weights file exists
- Ensure sufficient disk space for checkpoints
- Check that video files are not corrupted

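For the last point, a cheap first pass is to look for zero-byte videos, which usually indicate an interrupted download or extraction (a full decode check, e.g. with ffprobe if you have ffmpeg installed, is slower but also catches truncated files):

```shell
DATA=videos/20bn-something-something-v2   # adjust to your layout
if [ -d "$DATA" ]; then
  # zero-byte .webm files are almost certainly corrupted
  find "$DATA" -name "*.webm" -size 0 -print
fi
```

Any file this prints should be re-extracted from the archives before training.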
### Next Steps

After installation:
