#!/bin/bash
#
# IN1K_prep/download_imagenet.sh  (new file in this patch, mode 100755)
#
# Download and unpack the ImageNet-1k (ILSVRC2012) dataset.
# You need to first register and get credentials from image-net.org.
#
# Everything is created relative to the current working directory:
#   imagenet/train/<class>/   one directory per class archive in the train tar
#   imagenet/validation/      validation images, extracted as-is
#   the extracted contents of ILSVRC2012_devkit_t12.tar.gz
#
# The three downloaded archives are deleted after extraction.

# Abort on the first failing command, on use of an unset variable, and on a
# failure anywhere in a pipeline. Without this a failed download or `cd` would
# let the later `rm` calls run in the wrong place.
set -euo pipefail

# Replace these with your ImageNet credentials, or export
# IMAGENET_USERNAME / IMAGENET_PASSWORD before running so they never have to be
# written into this file.
USERNAME="${IMAGENET_USERNAME:-user}"
PASSWORD="${IMAGENET_PASSWORD:-pass}"

BASE_URL="https://image-net.org/data/ILSVRC/2012"

# fetch <archive-name>
# Download one archive from BASE_URL into the current directory.
# --continue resumes a partially downloaded file on re-run instead of saving a
# second copy next to it.
fetch() {
    wget --continue --user="$USERNAME" --password="$PASSWORD" "$BASE_URL/$1"
}

# Create directories
mkdir -p imagenet/train imagenet/validation

# Download train and validation archives
# Training images (138GB)
fetch ILSVRC2012_img_train.tar

# Validation images (6.3GB)
fetch ILSVRC2012_img_val.tar

# Download mapping files
fetch ILSVRC2012_devkit_t12.tar.gz

# Extract validation images
tar -xvf ILSVRC2012_img_val.tar -C imagenet/validation

# Extract training images (this yields one .tar per class)
tar -xvf ILSVRC2012_img_train.tar -C imagenet/train

# Extract individual class archives.
# Runs in a subshell so the directory change and the nullglob setting cannot
# leak into the rest of the script.
(
    cd imagenet/train
    # With nullglob an empty match skips the loop instead of iterating over the
    # literal string "*.tar".
    shopt -s nullglob
    for archive in *.tar; do
        class_dir="${archive%.tar}"
        mkdir -p "$class_dir"
        tar -xvf "$archive" -C "$class_dir"
        # Only reached when extraction succeeded (set -e).
        rm -- "$archive"
    done
)

# Extract devkit
tar -xvf ILSVRC2012_devkit_t12.tar.gz

# Cleanup
rm -- ILSVRC2012_img_train.tar ILSVRC2012_img_val.tar ILSVRC2012_devkit_t12.tar.gz

echo "Download and extraction complete!"

# Optional: verify number of images
echo "Verifying image counts..."
echo "Number of training images: $(find imagenet/train -type f | wc -l)"
echo "Number of validation images: $(find imagenet/validation -type f | wc -l)"
echo "Now you should run process_imagenet_labels.py!"
#!/usr/bin/env python3
"""Sort the ImageNet-1k validation images into per-class folders.

Patch target: IN1K_prep/process_imagenet_labels.py (new file, mode 100644).

Given the extracted devkit directory and the directory holding the flat
``ILSVRC2012_val_*.JPEG`` files, this script

1. moves every validation image into ``<val_path>/<WNID>/`` using the devkit's
   ground-truth file and ``meta.mat``, and
2. writes ``imagenet_class_info.json`` next to ``val_path`` with the class id,
   words, gloss and validation-image count of every synset.

Reading ``meta.mat`` needs scipy.

Usage::

    python3 process_imagenet_labels.py <devkit_path> <val_path>
"""

import os
import json
import shutil
import argparse
import scipy.io as sio
from pathlib import Path
from collections import defaultdict


def _load_synsets(devkit_path):
    """Return ``(class_id, wnid, words, gloss)`` for every row of meta.mat.

    Args:
        devkit_path (Path): Path to ILSVRC2012_devkit_t12 directory

    Returns:
        list[tuple[int, str, str, str]]: one tuple per entry of the
        ``synsets`` array, in file order. ``words`` is the raw string; callers
        split it themselves.
    """
    meta = sio.loadmat(str(Path(devkit_path) / "data" / "meta.mat"))
    records = []
    for entry in meta["synsets"]:
        # Each entry wraps a single record; its fields are read by position:
        # 0 = ILSVRC2012_ID, 1 = WNID, 2 = words, 3 = gloss.
        row = entry[0]
        records.append(
            (int(row[0][0][0]), str(row[1][0]), str(row[2][0]), str(row[3][0]))
        )
    return records


def process_val_labels(devkit_path, val_path):
    """
    Read validation ground truth and organize validation images into class folders

    Line ``i`` (1-based) of the ground-truth file is the ILSVRC2012_ID of
    ``ILSVRC2012_val_{i:08d}.JPEG``; the image is moved to
    ``val_path/<WNID>/``. Blank lines in the ground-truth file are ignored.

    Images are moved straight into their class folder and nothing else in
    ``val_path`` is deleted, so the function can be re-run after an interrupted
    or completed run without losing data: files already in place are counted
    as such, files that are neither loose nor in place are reported as missing,
    and unrelated items are left alone (with a warning).

    Args:
        devkit_path (Path): Path to ILSVRC2012_devkit_t12 directory
        val_path (Path): Path to directory containing validation images

    Raises:
        KeyError: if a ground-truth label has no entry in meta.mat.
        ValueError: if a non-blank ground-truth line is not an integer.
    """
    devkit_path = Path(devkit_path)
    val_path = Path(val_path)
    print(f"Processing validation images in: {val_path}")

    # Read validation labels
    val_labels_path = devkit_path / "data" / "ILSVRC2012_validation_ground_truth.txt"
    with open(val_labels_path, "r") as f:
        val_labels = [int(line) for line in f if line.strip()]

    # Create ILSVRC2012_ID -> WNID (synset) mapping
    synset_mapping = {
        class_id: wnid for class_id, wnid, _, _ in _load_synsets(devkit_path)
    }

    # Create mapping of filename to class
    val_filename_to_class = {
        f"ILSVRC2012_val_{i:08d}.JPEG": synset_mapping[label]
        for i, label in enumerate(val_labels, 1)
    }
    class_ids = set(val_filename_to_class.values())

    # Create every class directory up front (also for classes whose images are
    # absent).
    for class_id in class_ids:
        (val_path / class_id).mkdir(parents=True, exist_ok=True)

    print("Moving files to their class directories...")
    moved_count = 0
    already_in_place = 0
    missing = []
    total_files = len(val_filename_to_class)

    for filename, class_id in val_filename_to_class.items():
        src_path = val_path / filename
        dst_path = val_path / class_id / filename

        if src_path.exists():
            shutil.move(str(src_path), str(dst_path))
            moved_count += 1
            if moved_count % 1000 == 0:
                print(f"Moved {moved_count}/{total_files} files...")
        elif dst_path.exists():
            already_in_place += 1
        else:
            missing.append(filename)

    if already_in_place:
        print(f"{already_in_place} files were already in their class directory")
    if missing:
        print(
            f"Warning: {len(missing)} expected files were not found "
            f"(first: {missing[0]})"
        )

    # Anything that is not one of the class directories is reported, never
    # removed.
    leftovers = sorted(p.name for p in val_path.iterdir() if p.name not in class_ids)
    if leftovers:
        print(
            f"Warning: left {len(leftovers)} unrecognised item(s) in {val_path} "
            f"untouched (first: {leftovers[0]})"
        )

    print("\nOrganization complete!")
    print(f"Moved {moved_count} files into {len(class_ids)} class directories")


def create_class_info(devkit_path, val_path):
    """
    Create a JSON file mapping synset IDs to human-readable labels and metadata

    The file is written to ``val_path.parent / "imagenet_class_info.json"`` as
    UTF-8. Each key is a WNID and each value holds ``class_id``, ``words`` (the
    meta.mat words string split on ", "), ``gloss`` and, for synsets that have
    a directory under ``val_path``, ``val_images`` (number of ``*.JPEG`` files
    in it). Directories whose name is not a WNID from meta.mat are ignored.

    Args:
        devkit_path (Path): Path to ILSVRC2012_devkit_t12 directory
        val_path (Path): Path to directory containing validation images
    """
    devkit_path = Path(devkit_path)
    val_path = Path(val_path)

    # Extract information for each synset
    class_info = {}
    for class_id, synset_id, words, gloss in _load_synsets(devkit_path):
        class_info[synset_id] = {
            "class_id": class_id,
            "words": words.split(", "),
            "gloss": gloss,
        }

    # Count validation images per class
    if val_path.exists():
        for class_dir in val_path.iterdir():
            if class_dir.is_dir() and class_dir.name in class_info:
                count = sum(1 for _ in class_dir.glob("*.JPEG"))
                class_info[class_dir.name]["val_images"] = count

    # Save to JSON file. The encoding is explicit because ensure_ascii=False
    # writes non-ASCII characters verbatim.
    output_file = val_path.parent / "imagenet_class_info.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(class_info, f, indent=2, ensure_ascii=False)

    # Print example entries
    print("\nExample class mappings:")
    print("-" * 80)
    for synset_id in list(class_info)[:5]:
        entry = class_info[synset_id]
        print(f"Synset ID: {synset_id}")
        print(f"Class ID: {entry['class_id']}")
        print(f"Labels: {', '.join(entry['words'])}")
        print(f"Description: {entry['gloss']}")
        print(f"Validation images: {entry.get('val_images', 0)}")
        print("-" * 80)


def main():
    """Parse the two positional paths, validate them, then run both steps.

    Raises:
        FileNotFoundError: if either path, ``data/meta.mat``, the ground-truth
            file, or any loose ``ILSVRC2012_val_*.JPEG`` image is missing.
    """
    parser = argparse.ArgumentParser(
        description="Process ImageNet labels and create class mappings"
    )
    parser.add_argument(
        "devkit_path", type=str, help="Path to ILSVRC2012_devkit_t12 directory"
    )
    parser.add_argument(
        "val_path", type=str, help="Path to directory containing validation images"
    )
    args = parser.parse_args()

    devkit_path = Path(args.devkit_path)
    val_path = Path(args.val_path)

    # Verify paths and files
    if not devkit_path.exists():
        raise FileNotFoundError(f"Devkit path does not exist: {devkit_path}")
    if not val_path.exists():
        raise FileNotFoundError(f"Validation images path does not exist: {val_path}")
    if not (devkit_path / "data" / "meta.mat").exists():
        raise FileNotFoundError(f"Could not find meta.mat in {devkit_path}/data")
    if not (devkit_path / "data" / "ILSVRC2012_validation_ground_truth.txt").exists():
        raise FileNotFoundError(
            f"Could not find validation ground truth file in {devkit_path}/data"
        )

    # Check if validation directory contains images
    if not any(val_path.glob("ILSVRC2012_val_*.JPEG")):
        raise FileNotFoundError(f"No validation images found in {val_path}")

    print("Organizing validation images into class folders...")
    process_val_labels(devkit_path, val_path)

    print("\nCreating class information JSON file...")
    create_class_info(devkit_path, val_path)

    print(
        f"\nDone! Check {val_path.parent}/imagenet_class_info.json for complete class mappings."
    )


if __name__ == "__main__":
    main()

# python3 process_imagenet_labels.py /weka/proj-medarc/shared/imagenet/ILSVRC2012_devkit_t12 /weka/proj-medarc/shared/imagenet/validation

# ---------------------------------------------------------------------------
# Patch text that shared these physical lines with the module above. It is not
# part of this module and is kept as found, one patch line per comment line
# (context-line indentation was already lost in the source).
# ---------------------------------------------------------------------------
# diff --git a/requirements.txt b/requirements.txt
# index 16d34bf..6a95df1 100644
# --- a/requirements.txt
# +++ b/requirements.txt
# @@ -1,8 +1,11 @@
# datasets==2.15.0
# +einops==0.8.0
# fvcore==0.1.5.post20221221
# +h5py==3.12.1
# matplotlib==3.8.2
# numpy==1.26.2
# Pillow==10.1.0
# +scipy==1.14.1
# timm==1.0.3
# torch==2.3.0
# torchvision==0.18.0
# diff --git a/topomoe/src/topomoe/train.py b/topomoe/src/topomoe/train.py
# index 4adda2a..a50e0b0 100644
# --- a/topomoe/src/topomoe/train.py
# +++ b/topomoe/src/topomoe/train.py
# @@ -11,6 +11,7 @@ from functools import partial
# from pathlib import Path
# from typing import Callable, Dict, List, Optional, Tuple
# +import PIL
# import numpy as np
# import torch
# @@ -81,13 +82,16 @@ class Args:
# aliases=["--wsigma"], default=2.0, help="wiring length radius stdev"
# )
# # Dataset
# - dataset: str = HfArg(
# - default="hfds/clane9/imagenet-100", help="timm-compatible dataset name"
# +    dataset: Optional[str] = HfArg(
# +        default="hfds/clane9/imagenet-100",
# +        help="Dataset name (timm compatible).
# ---------------------------------------------------------------------------
# Patch text for topomoe/src/topomoe/train.py that shares these physical lines
# with the functions below. These hunks show only part of `class Args` and of
# `main()`, so they are kept as found (one patch line per comment line; the
# indentation of context lines was already lost in the source) and not edited.
#
# NOTE(review): in the main() hunk, `Path(args.data_dir)` raises TypeError when
# dataset == "folder" and data_dir is left at its default of None.
# NOTE(review): the main() hunk overwrites args.train_split / args.val_split
# with "" — confirm nothing later in main() reads the original split names.
# ---------------------------------------------------------------------------
# If 'folder', uses data_dir which contains train and val folder of images.",
# )
# - data_dir: Optional[str] = HfArg(default=None, help="dataset directory")
# - download: bool = HfArg(default=True, help="download dataset")
# +    data_dir: Optional[str] = HfArg(
# +        default=None, help="Dataset directory containing train/ and validation/"
# +    )
# +    download: bool = HfArg(default=False, help="download dataset if needed")
# train_split: str = HfArg(default="train", help="name of training split")
# - val_split: str = HfArg(default="validation", help="name of val split")
# +    val_split: str = HfArg(default="validation", help="name of validation split")
# train_num_samples: Optional[int] = HfArg(
# default=None,
# help="Manually specify num samples in train split, for IterableDatasets",
# @@ -241,9 +245,14 @@ def main(args: Args):
# # Dataset
# logging.info("Loading dataset %s", args.dataset)
# + if args.dataset == "folder":
# + root_dir = Path(args.data_dir) / args.train_split
# + args.train_split = ""
# + else:
# + root_dir = args.data_dir
# dataset_train = create_dataset(
# args.dataset,
# - root=args.data_dir,
# + root=root_dir,
# split=args.train_split,
# is_training=True,
# download=args.download,
# @@ -251,9 +260,12 @@ def main(args: Args):
# num_samples=args.train_num_samples,
# repeats=args.epoch_repeats,
# )
# + if args.dataset == "folder":
# + root_dir = Path(args.data_dir) / args.val_split
# + args.val_split = ""
# dataset_eval = create_dataset(
# args.dataset,
# - root=args.data_dir,
# + root=root_dir,
# split=args.val_split,
# is_training=False,
# download=args.download,
# @@ -750,16 +762,23 @@ def validate(


def load_dataset_in_memory(dataset: ImageDataset):
    """Replace the dataset's lazily-read samples with in-memory copies, in place.

    * ``ReaderHfds`` reader: the reader's dataset is rebuilt with
      ``Dataset.from_dict(dataset.to_dict(), features=...)``.
    * Any other reader: if the dataset exposes a ``samples`` sequence of
      ``(path, target)`` pairs, each path is opened, converted to an RGB image
      and stored back as ``(image, target)``. If there is no ``samples``
      attribute the dataset is left untouched and a warning is logged, so the
      caller keeps a working (lazy) dataset instead of an AttributeError.

    NOTE(review): whether the dataset class used for ``dataset="folder"``
    actually has ``samples`` on the dataset object (rather than on its reader),
    and whether its ``__getitem__`` accepts pre-decoded images, is not visible
    here — confirm before relying on the non-HF branch.

    Args:
        dataset: dataset whose ``reader`` (and optionally ``samples``) is
            updated in place.
    """
    if isinstance(dataset.reader, ReaderHfds):
        hf_dataset = dataset.reader.dataset
        dataset.reader.dataset = Dataset.from_dict(
            hf_dataset.to_dict(),
            features=hf_dataset.features,
        )
        return

    samples = getattr(dataset, "samples", None)
    if samples is None:
        logging.warning(
            "Dataset %s has no 'samples' attribute; leaving it on disk",
            type(dataset).__name__,
        )
        return

    # Local import: a bare `import PIL` does not guarantee that the PIL.Image
    # submodule has been loaded.
    from PIL import Image

    in_memory = []
    for path, target in samples:
        # convert() returns a new, fully loaded image, so the file handle can be
        # closed as soon as the block exits.
        with Image.open(path) as img:
            in_memory.append((img.convert("RGB"), target))
    dataset.samples = in_memory


def get_num_classes(dataset: ImageDataset):
    """Return the number of target classes of ``dataset``.

    * ``ReaderHfds`` reader: taken from the ``label`` feature of the HF dataset.
    * Otherwise: the size of the reader's ``class_to_idx`` mapping when the
      reader has a non-empty one, else 1000 (ImageNet-1k), which was the only
      value returned for non-HF datasets before.

    NOTE(review): assumes folder-style readers expose ``class_to_idx`` — confirm
    against the reader class actually used.
    """
    if isinstance(dataset.reader, ReaderHfds):
        return dataset.reader.dataset.features["label"].num_classes

    class_to_idx = getattr(dataset.reader, "class_to_idx", None)
    if class_to_idx:
        return len(class_to_idx)
    return 1000  # for ImageNet1k


# ---------------------------------------------------------------------------
# Remaining patch text from these physical lines (train.py trailing context and
# the topomoe/tests/test_train.py config hunk), kept as found.
# ---------------------------------------------------------------------------
# @torch.no_grad()
# diff --git a/topomoe/tests/test_train.py b/topomoe/tests/test_train.py
# index 06e1c81..abc3216 100644
# --- a/topomoe/tests/test_train.py
# +++ b/topomoe/tests/test_train.py
# @@ -64,10 +64,11 @@
# ),
# "topomoe": train.Args(
# name="debug_train_topomoe",
# + data_dir="/weka/proj-medarc/shared/imagenet",
# out_dir="topomoe/test_results",
# model="topomoe_tiny_2s_patch16_128",
# wiring_lambd=0.01,
# - dataset="hfds/clane9/imagenet-100",
# + dataset="folder",
# workers=0,
# batch_size=32,
# overwrite=True,
# @@ -118,3 +119,29 @@ def test_train(config: str):
# args = configs[config]
# train.main(args)

# Added at the end of topomoe/tests/test_train.py: run selected configurations
# directly, without pytest. `configs` and `test_train` are defined earlier in
# that file.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run training configurations.")
    parser.add_argument(
        "--config",
        nargs="+",
        default=None,
        help="List of configurations to run. If not specified, all configurations will be run.",
    )

    parsed_args = parser.parse_args()

    if parsed_args.config is not None:
        configs_to_run = parsed_args.config
    else:
        configs_to_run = list(configs)

    unknown = []
    for config in configs_to_run:
        if config in configs:
            print(f"Running configuration: {config}")
            test_train(config)
        else:
            print(f"Configuration '{config}' not found.")
            unknown.append(config)

    # A mistyped name used to end with exit status 0; report it to the caller
    # after the known configurations have still been run.
    if unknown:
        raise SystemExit(1)