Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion topomoe/src/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from timm.layers import trunc_normal_
from timm.layers.helpers import to_2tuple, to_3tuple

from topomoe.utils import filter_kwargs
from topomoe.src.utils import filter_kwargs

State = Dict[str, torch.Tensor]
Layer = Callable[..., nn.Module]
Expand Down
56 changes: 56 additions & 0 deletions topomoe/src/models/topomoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,59 @@ def topomoe_tiny_3s_patch16_128(**kwargs):
}
model = model_factory(TopoMoETransformer, params, defaults, **kwargs)
return model




@register_model
def vit_tiny_patch16_128(**kwargs):
params = {
"img_size": 128,
"patch_size": 16,
"in_chans": 3,
"depths": (6,),
"widths": None,
"embed_dim": 384,
"num_experts": (1,),
}
defaults = {
"num_heads": 6,
}
model = model_factory(TopoMoETransformer, params, defaults, **kwargs)
return model


@register_model
def vit_small_patch16_128(**kwargs):
params = {
"img_size": 128,
"patch_size": 16,
"in_chans": 3,
"depths": (12,),
"widths": None,
"embed_dim": 384,
"num_experts": (1,),
}
defaults = {
"num_heads": 12,
}
model = model_factory(TopoMoETransformer, params, defaults, **kwargs)
return model


@register_model
def vit_base_patch16_128(**kwargs):
params = {
"img_size": 128,
"patch_size": 16,
"in_chans": 3,
"depths": (12,),
"widths": None,
"embed_dim": 768,
"num_experts": (1,),
}
defaults = {
"num_heads": 12,
}
model = model_factory(TopoMoETransformer, params, defaults, **kwargs)
return model
11 changes: 7 additions & 4 deletions topomoe/src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from torch.utils.data import DataLoader
from transformers.hf_argparser import HfArg, HfArgumentParser

from topomoe import utils as ut
from topomoe.inspection import Figure, Metric, create_figures, create_metrics
from topomoe.models import create_model, list_models
from topomoe.src import utils as ut
from topomoe.src.inspection import Figure, Metric, create_figures, create_metrics
from topomoe.src.models import create_model, list_models

np.set_printoptions(precision=3)
plt.switch_backend("Agg")
Expand Down Expand Up @@ -116,7 +116,7 @@ class Args:
aliases=["--inmem"], default=False, help="keep dataset in memory"
)
# Optimization
epochs: int = HfArg(default=100, help="number of epochs")
epochs: int = HfArg(default=2, help="number of epochs")
batch_size: int = HfArg(
aliases=["--bs"], default=256, help="batch size per replica"
)
Expand Down Expand Up @@ -241,6 +241,7 @@ def main(args: Args):

# Dataset
logging.info("Loading dataset %s", args.dataset)
import pdb;pdb.set_trace()
dataset_train = create_dataset(
args.dataset,
root=args.data_dir,
Expand Down Expand Up @@ -283,6 +284,7 @@ def main(args: Args):
device=clust.device,
use_prefetcher=args.prefetch,
)
import pdb;pdb.set_trace()
loader_eval = create_loader(
dataset_eval,
input_size=input_size,
Expand Down Expand Up @@ -516,6 +518,7 @@ def train_one_epoch(
data_time = time.monotonic() - end

# forward pass
#import pdb;pdb.set_trace()
with autocast():
output, losses, state = model(input)
losses["class_loss"] = loss_fn(output, target)
Expand Down
136 changes: 134 additions & 2 deletions topomoe/tests/test_train.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,39 @@
import pytest
import sys
sys.path.append('../')
from topomoe.src import train

from topomoe import train
# Replace with your path to Imagenet1k (ILSVRC2012) with train and val folders
imagenet_path = "../../datasets/ILSVRC2012/"

configs = {
"vit_small": train.Args(
name="debug_train_vit_small",
out_dir="topomoe/test_results",
model="vit_small_patch16_128",
dataset= "hfds/clane9/imagenet-100",
workers=1,
batch_size=1024,
overwrite=True,
debug=True,
),
"vit_base": train.Args(
name="debug_train_vit_small",
out_dir="topomoe/test_results",
model="vit_base_patch16_128",
dataset= "hfds/clane9/imagenet-100",
workers=1,
batch_size=1024,
overwrite=True,
debug=True,
),
"transformer": train.Args(
name="debug_train_transformer",
out_dir="topomoe/test_results",
model="quadmoe_tiny_1s_patch16_128",
dataset="hfds/clane9/imagenet-100",
workers=0,
batch_size=32,
batch_size=1024,
overwrite=True,
debug=True,
),
Expand Down Expand Up @@ -68,12 +92,118 @@
overwrite=True,
debug=True,
),
## Imagenet1k
"vit_small_imagenet1k": train.Args(
name="debug_train_vit_small",
out_dir="topomoe/test_results",
model="vit_small_patch16_128",
dataset= "imagenet1k",
data_dir = imagenet_path,
workers=1,
num_classes=1000,
batch_size=1024,
overwrite=True,
debug=True,
),
"vit_base_imagenet1k": train.Args(
name="debug_train_vit_small",
out_dir="topomoe/test_results",
model="vit_base_patch16_128",
dataset= "imagenet1k",
data_dir = imagenet_path,
workers=1,
num_classes=1000,
batch_size=1024,
overwrite=True,
debug=True,
),
"transformer_imagenet1k": train.Args(
name="debug_train_transformer",
out_dir="topomoe/test_results",
model="quadmoe_tiny_1s_patch16_128",
dataset="imagenet1k",
data_dir = imagenet_path,
workers=0,
num_classes=1000,
batch_size=1024,
overwrite=True,
debug=True,
),
"transformer_v2_imagenet1k": train.Args(
name="debug_train_transformer_v2",
out_dir="topomoe/test_results",
model="topomoe_tiny_1s_patch16_128",
dataset="imagenet1k",
data_dir = imagenet_path,
workers=0,
num_classes=1000,
batch_size=32,
overwrite=True,
debug=True,
),
"quadmoe_imagenet1k": train.Args(
name="debug_train_quadmoe",
out_dir="topomoe/test_results",
model="quadmoe_tiny_2s_patch16_128",
dataset="imagenet1k",
data_dir = imagenet_path,
workers=0,
num_classes=1000,
batch_size=32,
overwrite=True,
debug=True,
),
"softmoe_imagenet1k": train.Args(
name="debug_train_softmoe",
out_dir="topomoe/test_results",
model="softmoe_tiny_2s_patch16_128",
dataset="imagenet1k",
data_dir = imagenet_path,
workers=0,
num_classes=1000,
batch_size=32,
overwrite=True,
debug=True,
),
"topomoe_imagenet1k": train.Args(
name="debug_train_topomoe",
out_dir="topomoe/test_results",
model="topomoe_tiny_2s_patch16_128",
wiring_lambd=0.01,
dataset="imagenet1k",
data_dir = imagenet_path,
workers=0,
num_classes=1000,
batch_size=32,
overwrite=True,
debug=True,
),
"aug_imagenet1k": train.Args(
name="debug_train_aug",
out_dir="topomoe/test_results",
model="quadmoe_tiny_1s_patch16_128",
dataset="imagenet1k",
data_dir = imagenet_path,
num_classes=1000,
scale=[0.1, 0.3],
ratio=[1 / 4, 4 / 1],
hflip=0.5,
color_jitter=0.4,
workers=0,
batch_size=32,
overwrite=True,
debug=True,
),
}




@pytest.mark.parametrize(
"config",
[
"vit_small",
"vit_base",
"transformer",
"transformer_v2",
"quadmoe",
Expand All @@ -85,3 +215,5 @@
def test_train(config: str):
args = configs[config]
train.main(args)

test_train("vit_small_imagenet1k")