Switch imports to the topomoe.src package, register vit_{tiny,small,base}_patch16_128
model factories, and extend the training smoke-test configurations.

Review notes: debugger breakpoints (pdb.set_trace) in train.py, the override of
the `epochs` default, and the import-time `test_train(...)` call in the test
module were removed from this patch. Post-image blob ids on the `index` lines
were not recomputed after those edits.

diff --git a/topomoe/src/models/common.py b/topomoe/src/models/common.py
index 5f922de..8170086 100644
--- a/topomoe/src/models/common.py
+++ b/topomoe/src/models/common.py
@@ -6,7 +6,7 @@
 from timm.layers import trunc_normal_
 from timm.layers.helpers import to_2tuple, to_3tuple
 
-from topomoe.utils import filter_kwargs
+from topomoe.src.utils import filter_kwargs
 
 State = Dict[str, torch.Tensor]
 Layer = Callable[..., nn.Module]
diff --git a/topomoe/src/models/topomoe.py b/topomoe/src/models/topomoe.py
index a229ed9..50a46e0 100644
--- a/topomoe/src/models/topomoe.py
+++ b/topomoe/src/models/topomoe.py
@@ -504,3 +504,60 @@ def topomoe_tiny_3s_patch16_128(**kwargs):
     }
     model = model_factory(TopoMoETransformer, params, defaults, **kwargs)
     return model
+
+
+@register_model
+def vit_tiny_patch16_128(**kwargs):
+    """TopoMoETransformer with depths=(6,), embed_dim=384, num_experts=(1,)."""
+    params = {
+        "img_size": 128,
+        "patch_size": 16,
+        "in_chans": 3,
+        "depths": (6,),
+        "widths": None,
+        "embed_dim": 384,
+        "num_experts": (1,),
+    }
+    defaults = {
+        "num_heads": 6,
+    }
+    model = model_factory(TopoMoETransformer, params, defaults, **kwargs)
+    return model
+
+
+@register_model
+def vit_small_patch16_128(**kwargs):
+    """TopoMoETransformer with depths=(12,), embed_dim=384, num_experts=(1,)."""
+    params = {
+        "img_size": 128,
+        "patch_size": 16,
+        "in_chans": 3,
+        "depths": (12,),
+        "widths": None,
+        "embed_dim": 384,
+        "num_experts": (1,),
+    }
+    defaults = {
+        "num_heads": 12,
+    }
+    model = model_factory(TopoMoETransformer, params, defaults, **kwargs)
+    return model
+
+
+@register_model
+def vit_base_patch16_128(**kwargs):
+    """TopoMoETransformer with depths=(12,), embed_dim=768, num_experts=(1,)."""
+    params = {
+        "img_size": 128,
+        "patch_size": 16,
+        "in_chans": 3,
+        "depths": (12,),
+        "widths": None,
+        "embed_dim": 768,
+        "num_experts": (1,),
+    }
+    defaults = {
+        "num_heads": 12,
+    }
+    model = model_factory(TopoMoETransformer, params, defaults, **kwargs)
+    return model
diff --git a/topomoe/src/train.py b/topomoe/src/train.py
index 4adda2a..3924737 100644
--- a/topomoe/src/train.py
+++ b/topomoe/src/train.py
@@ -28,9 +28,9 @@
 from torch.utils.data import DataLoader
 from transformers.hf_argparser import HfArg, HfArgumentParser
 
-from topomoe import utils as ut
-from topomoe.inspection import Figure, Metric, create_figures, create_metrics
-from topomoe.models import create_model, list_models
+from topomoe.src import utils as ut
+from topomoe.src.inspection import Figure, Metric, create_figures, create_metrics
+from topomoe.src.models import create_model, list_models
 
 np.set_printoptions(precision=3)
 plt.switch_backend("Agg")
diff --git a/topomoe/tests/test_train.py b/topomoe/tests/test_train.py
index b541d94..1bcd450 100644
--- a/topomoe/tests/test_train.py
+++ b/topomoe/tests/test_train.py
@@ -1,15 +1,39 @@
 import pytest
+import sys
+sys.path.append('../')
+from topomoe.src import train
 
-from topomoe import train
+# Replace with your path to Imagenet1k (ILSVRC2012) with train and val folders
+imagenet_path = "../../datasets/ILSVRC2012/"
 
 configs = {
+    "vit_small": train.Args(
+        name="debug_train_vit_small",
+        out_dir="topomoe/test_results",
+        model="vit_small_patch16_128",
+        dataset="hfds/clane9/imagenet-100",
+        workers=1,
+        batch_size=1024,
+        overwrite=True,
+        debug=True,
+    ),
+    "vit_base": train.Args(
+        name="debug_train_vit_base",
+        out_dir="topomoe/test_results",
+        model="vit_base_patch16_128",
+        dataset="hfds/clane9/imagenet-100",
+        workers=1,
+        batch_size=1024,
+        overwrite=True,
+        debug=True,
+    ),
     "transformer": train.Args(
         name="debug_train_transformer",
         out_dir="topomoe/test_results",
         model="quadmoe_tiny_1s_patch16_128",
         dataset="hfds/clane9/imagenet-100",
         workers=0,
-        batch_size=32,
+        batch_size=1024,
         overwrite=True,
         debug=True,
     ),
@@ -68,12 +92,116 @@
         overwrite=True,
         debug=True,
     ),
+    # ImageNet-1k variants; these read the dataset from imagenet_path
+    "vit_small_imagenet1k": train.Args(
+        name="debug_train_vit_small",
+        out_dir="topomoe/test_results",
+        model="vit_small_patch16_128",
+        dataset="imagenet1k",
+        data_dir=imagenet_path,
+        workers=1,
+        num_classes=1000,
+        batch_size=1024,
+        overwrite=True,
+        debug=True,
+    ),
+    "vit_base_imagenet1k": train.Args(
+        name="debug_train_vit_base",
+        out_dir="topomoe/test_results",
+        model="vit_base_patch16_128",
+        dataset="imagenet1k",
+        data_dir=imagenet_path,
+        workers=1,
+        num_classes=1000,
+        batch_size=1024,
+        overwrite=True,
+        debug=True,
+    ),
+    "transformer_imagenet1k": train.Args(
+        name="debug_train_transformer",
+        out_dir="topomoe/test_results",
+        model="quadmoe_tiny_1s_patch16_128",
+        dataset="imagenet1k",
+        data_dir=imagenet_path,
+        workers=0,
+        num_classes=1000,
+        batch_size=1024,
+        overwrite=True,
+        debug=True,
+    ),
+    "transformer_v2_imagenet1k": train.Args(
+        name="debug_train_transformer_v2",
+        out_dir="topomoe/test_results",
+        model="topomoe_tiny_1s_patch16_128",
+        dataset="imagenet1k",
+        data_dir=imagenet_path,
+        workers=0,
+        num_classes=1000,
+        batch_size=32,
+        overwrite=True,
+        debug=True,
+    ),
+    "quadmoe_imagenet1k": train.Args(
+        name="debug_train_quadmoe",
+        out_dir="topomoe/test_results",
+        model="quadmoe_tiny_2s_patch16_128",
+        dataset="imagenet1k",
+        data_dir=imagenet_path,
+        workers=0,
+        num_classes=1000,
+        batch_size=32,
+        overwrite=True,
+        debug=True,
+    ),
+    "softmoe_imagenet1k": train.Args(
+        name="debug_train_softmoe",
+        out_dir="topomoe/test_results",
+        model="softmoe_tiny_2s_patch16_128",
+        dataset="imagenet1k",
+        data_dir=imagenet_path,
+        workers=0,
+        num_classes=1000,
+        batch_size=32,
+        overwrite=True,
+        debug=True,
+    ),
+    "topomoe_imagenet1k": train.Args(
+        name="debug_train_topomoe",
+        out_dir="topomoe/test_results",
+        model="topomoe_tiny_2s_patch16_128",
+        wiring_lambd=0.01,
+        dataset="imagenet1k",
+        data_dir=imagenet_path,
+        workers=0,
+        num_classes=1000,
+        batch_size=32,
+        overwrite=True,
+        debug=True,
+    ),
+    "aug_imagenet1k": train.Args(
+        name="debug_train_aug",
+        out_dir="topomoe/test_results",
+        model="quadmoe_tiny_1s_patch16_128",
+        dataset="imagenet1k",
+        data_dir=imagenet_path,
+        num_classes=1000,
+        scale=[0.1, 0.3],
+        ratio=[1 / 4, 4 / 1],
+        hflip=0.5,
+        color_jitter=0.4,
+        workers=0,
+        batch_size=32,
+        overwrite=True,
+        debug=True,
+    ),
 }
 
 
 @pytest.mark.parametrize(
     "config",
     [
+        "vit_small",
+        "vit_base",
         "transformer",
         "transformer_v2",
         "quadmoe",