Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions IN1K_prep/download_imagenet.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/bin/bash

# This script downloads ImageNet-1k (ILSVRC2012) dataset
# You need to first register and get credentials from image-net.org

# Replace these with your ImageNet credentials
USERNAME="user"
PASSWORD="pass"

# Create directories
mkdir -p imagenet/train imagenet/validation

# Download train and validation archives
# Training images (138GB)
wget --user=$USERNAME --password=$PASSWORD https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar

# Validation images (6.3GB)
wget --user=$USERNAME --password=$PASSWORD https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar

# Download mapping files
wget --user=$USERNAME --password=$PASSWORD https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz

# Extract validation images
cd imagenet/validation
tar xvf ../../ILSVRC2012_img_val.tar
cd ../..

# Extract training images
cd imagenet/train
tar xvf ../../ILSVRC2012_img_train.tar

# Extract individual class archives
for f in *.tar; do
d=`basename $f .tar`
mkdir -p $d
cd $d
tar xvf ../$f
cd ..
rm $f
done

cd ../..

# Extract devkit
tar xvf ILSVRC2012_devkit_t12.tar.gz

# Cleanup
rm ILSVRC2012_img_train.tar ILSVRC2012_img_val.tar ILSVRC2012_devkit_t12.tar.gz

echo "Download and extraction complete!"

# Optional: verify number of images
echo "Verifying image counts..."
echo "Number of training images: $(find imagenet/train -type f | wc -l)"
echo "Number of validation images: $(find imagenet/validation -type f | wc -l)"
echo "Now you should run process_imagenet_labels.py!"
181 changes: 181 additions & 0 deletions IN1K_prep/process_imagenet_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#!/usr/bin/env python3

import os
import json
import shutil
import argparse
import scipy.io as sio
from pathlib import Path
from collections import defaultdict


def process_val_labels(devkit_path, val_path):
"""
Read validation ground truth and organize validation images into class folders

Args:
devkit_path (Path): Path to ILSVRC2012_devkit_t12 directory
val_path (Path): Path to directory containing validation images
"""
print(f"Processing validation images in: {val_path}")

# Read validation labels
val_labels_path = devkit_path / "data" / "ILSVRC2012_validation_ground_truth.txt"
with open(val_labels_path, "r") as f:
val_labels = [int(line.strip()) for line in f.readlines()]

# Read meta.mat
meta = sio.loadmat(str(devkit_path / "data" / "meta.mat"))
synsets = meta["synsets"]

# Create ILSVRC2012_ID -> WNID (synset) mapping
synset_mapping = {int(s[0][0][0][0]): str(s[0][1][0]) for s in synsets}

# Create mapping of filename to class
val_filename_to_class = {}
for i, label in enumerate(val_labels, 1):
filename = f"ILSVRC2012_val_{i:08d}.JPEG"
val_filename_to_class[filename] = synset_mapping[label]

# Create temporary directory for moving files
temp_dir = val_path / "temp_organization"
os.makedirs(temp_dir, exist_ok=True)

# First, create all class directories in temp
for class_id in set(val_filename_to_class.values()):
os.makedirs(temp_dir / class_id, exist_ok=True)

# Move files to their class directories in temp
print("Moving files to their class directories...")
moved_count = 0
total_files = len(val_filename_to_class)

for filename, class_id in val_filename_to_class.items():
src_path = val_path / filename
dst_dir = temp_dir / class_id
dst_path = dst_dir / filename

if src_path.exists():
shutil.move(str(src_path), str(dst_path))
moved_count += 1
if moved_count % 1000 == 0:
print(f"Moved {moved_count}/{total_files} files...")

# Remove any remaining files in val_path (except temp_organization)
for item in val_path.iterdir():
if item.name != "temp_organization":
if item.is_file():
item.unlink()
elif item.is_dir():
shutil.rmtree(item)

# Move everything from temp back to val_path
for class_dir in temp_dir.iterdir():
if class_dir.is_dir():
shutil.move(str(class_dir), str(val_path / class_dir.name))

# Remove temp directory
shutil.rmtree(temp_dir)

print("\nOrganization complete!")
print(
f"Moved {moved_count} files into {len(set(val_filename_to_class.values()))} class directories"
)


def create_class_info(devkit_path, val_path):
"""
Create a JSON file mapping synset IDs to human-readable labels and metadata

Args:
devkit_path (Path): Path to ILSVRC2012_devkit_t12 directory
val_path (Path): Path to directory containing validation images
"""
class_info = defaultdict(dict)

# Read meta.mat
meta = sio.loadmat(str(devkit_path / "data" / "meta.mat"))
synsets = meta["synsets"]

# Extract information for each synset
for s in synsets:
synset_id = str(s[0][1][0]) # WNID
class_id = int(s[0][0][0][0]) # ILSVRC2012_ID
words = str(s[0][2][0]).split(", ")
gloss = str(s[0][3][0])

class_info[synset_id].update(
{"class_id": class_id, "words": words, "gloss": gloss}
)

# Count validation images per class
if val_path.exists():
for class_dir in val_path.iterdir():
if class_dir.is_dir():
count = len(list(class_dir.glob("*.JPEG")))
class_info[class_dir.name]["val_images"] = count

# Save to JSON file
output_file = val_path.parent / "imagenet_class_info.json"
with open(output_file, "w") as f:
json.dump(class_info, f, indent=2, ensure_ascii=False)

# Print example entries
print("\nExample class mappings:")
print("-" * 80)
for synset_id in list(class_info.keys())[:5]:
print(f"Synset ID: {synset_id}")
print(f"Class ID: {class_info[synset_id]['class_id']}")
print(f"Labels: {', '.join(class_info[synset_id]['words'])}")
print(f"Description: {class_info[synset_id]['gloss']}")
print(f"Validation images: {class_info[synset_id].get('val_images', 0)}")
print("-" * 80)


def main():
parser = argparse.ArgumentParser(
description="Process ImageNet labels and create class mappings"
)
parser.add_argument(
"devkit_path", type=str, help="Path to ILSVRC2012_devkit_t12 directory"
)
parser.add_argument(
"val_path", type=str, help="Path to directory containing validation images"
)
args = parser.parse_args()

devkit_path = Path(args.devkit_path)
val_path = Path(args.val_path)

# Verify paths and files
if not devkit_path.exists():
raise FileNotFoundError(f"Devkit path does not exist: {devkit_path}")
if not val_path.exists():
raise FileNotFoundError(f"Validation images path does not exist: {val_path}")
if not (devkit_path / "data" / "meta.mat").exists():
raise FileNotFoundError(f"Could not find meta.mat in {devkit_path}/data")
if not (devkit_path / "data" / "ILSVRC2012_validation_ground_truth.txt").exists():
raise FileNotFoundError(
f"Could not find validation ground truth file in {devkit_path}/data"
)

# Check if validation directory contains images
val_images = list(val_path.glob("ILSVRC2012_val_*.JPEG"))
if not val_images:
raise FileNotFoundError(f"No validation images found in {val_path}")

print("Organizing validation images into class folders...")
process_val_labels(devkit_path, val_path)

print("\nCreating class information JSON file...")
create_class_info(devkit_path, val_path)

print(
f"\nDone! Check {val_path.parent}/imagenet_class_info.json for complete class mappings."
)


if __name__ == "__main__":
main()

# python3 process_imagenet_labels.py /weka/proj-medarc/shared/imagenet/ILSVRC2012_devkit_t12 /weka/proj-medarc/shared/imagenet/validation
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
datasets==2.15.0
einops==0.8.0
fvcore==0.1.5.post20221221
h5py==3.12.1
matplotlib==3.8.2
numpy==1.26.2
Pillow==10.1.0
scipy==1.14.1
timm==1.0.3
torch==2.3.0
torchvision==0.18.0
Expand Down
47 changes: 33 additions & 14 deletions topomoe/src/topomoe/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple
import PIL

import numpy as np
import torch
Expand Down Expand Up @@ -81,13 +82,16 @@ class Args:
aliases=["--wsigma"], default=2.0, help="wiring length radius stdev"
)
# Dataset
dataset: str = HfArg(
default="hfds/clane9/imagenet-100", help="timm-compatible dataset name"
dataset: Optional[str] = HfArg(
default="hfds/clane9/imagenet-100",
help="Dataset name (timm compatible). If 'folder', uses data_dir which contains train and val folder of images.",
)
data_dir: Optional[str] = HfArg(default=None, help="dataset directory")
download: bool = HfArg(default=True, help="download dataset")
data_dir: Optional[str] = HfArg(
default=None, help="Dataset directory containing train/ and validation/"
)
download: bool = HfArg(default=False, help="download dataset if needed")
train_split: str = HfArg(default="train", help="name of training split")
val_split: str = HfArg(default="validation", help="name of val split")
val_split: str = HfArg(default="validation", help="name of validation split")
train_num_samples: Optional[int] = HfArg(
default=None,
help="Manually specify num samples in train split, for IterableDatasets",
Expand Down Expand Up @@ -241,19 +245,27 @@ def main(args: Args):

# Dataset
logging.info("Loading dataset %s", args.dataset)
if args.dataset == "folder":
root_dir = Path(args.data_dir) / args.train_split
args.train_split = ""
else:
root_dir = args.data_dir
dataset_train = create_dataset(
args.dataset,
root=args.data_dir,
root=root_dir,
split=args.train_split,
is_training=True,
download=args.download,
batch_size=args.batch_size,
num_samples=args.train_num_samples,
repeats=args.epoch_repeats,
)
if args.dataset == "folder":
root_dir = Path(args.data_dir) / args.val_split
args.val_split = ""
dataset_eval = create_dataset(
args.dataset,
root=args.data_dir,
root=root_dir,
split=args.val_split,
is_training=False,
download=args.download,
Expand Down Expand Up @@ -750,16 +762,23 @@ def validate(


def load_dataset_in_memory(dataset: ImageDataset):
assert isinstance(dataset.reader, ReaderHfds)
dataset.reader.dataset = Dataset.from_dict(
dataset.reader.dataset.to_dict(),
features=dataset.reader.dataset.features,
)
if isinstance(dataset.reader, ReaderHfds):
dataset.reader.dataset = Dataset.from_dict(
dataset.reader.dataset.to_dict(),
features=dataset.reader.dataset.features,
)
else:
dataset.samples = [
(PIL.Image.open(path).convert("RGB"), target)
for path, target in dataset.samples
]


def get_num_classes(dataset: ImageDataset):
assert isinstance(dataset.reader, ReaderHfds)
return dataset.reader.dataset.features["label"].num_classes
if isinstance(dataset.reader, ReaderHfds):
return dataset.reader.dataset.features["label"].num_classes
else:
return 1000 # for ImageNet1k


@torch.no_grad()
Expand Down
29 changes: 28 additions & 1 deletion topomoe/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@
),
"topomoe": train.Args(
name="debug_train_topomoe",
data_dir="/weka/proj-medarc/shared/imagenet",
out_dir="topomoe/test_results",
model="topomoe_tiny_2s_patch16_128",
wiring_lambd=0.01,
dataset="hfds/clane9/imagenet-100",
dataset="folder",
workers=0,
batch_size=32,
overwrite=True,
Expand Down Expand Up @@ -118,3 +119,29 @@
def test_train(config: str):
args = configs[config]
train.main(args)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="Run training configurations.")
parser.add_argument(
"--config",
nargs="+",
default=None,
help="List of configurations to run. If not specified, all configurations will be run.",
)

parsed_args = parser.parse_args()

if parsed_args.config is not None:
configs_to_run = parsed_args.config
else:
configs_to_run = configs.keys()

for config in configs_to_run:
if config in configs:
print(f"Running configuration: {config}")
test_train(config)
else:
print(f"Configuration '{config}' not found.")