Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions configs/selfsup/_base_/datasets/coco_orl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import copy

# dataset settings
dataset_type = 'mmdet.CocoDataset'
# data_root = 'data/coco/'
data_root = '../data/coco/'
file_client_args = dict(backend='disk')
view_pipeline = [
dict(
type='RandomResizedCrop',
size=224,
interpolation='bicubic',
backend='pillow'),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomApply',
transforms=[
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.2,
hue=0.1)
],
prob=0.8),
dict(
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.2989)),
dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=1),
dict(type='RandomSolarize', prob=0)
]
view_pipeline1 = copy.deepcopy(view_pipeline)
view_pipeline2 = copy.deepcopy(view_pipeline)
view_pipeline2[4]['prob'] = 0.1 # gaussian blur
view_pipeline2[5]['prob'] = 0.2 # solarization
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='MultiView',
num_views=[1, 1],
transforms=[view_pipeline1, view_pipeline2]),
dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]
train_dataloader = dict(
batch_size=64,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
pipeline=train_pipeline))
63 changes: 63 additions & 0 deletions configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
_base_ = [
'../../_base_/models/byol.py',
'../../_base_/datasets/coco_orl.py',
'../../_base_/schedules/sgd_coslr-200e_in1k.py',
'../../_base_/default_runtime.py',
]

# model settings
model = dict(
neck=dict(
type='NonLinearNeck',
in_channels=2048,
hid_channels=4096,
out_channels=256,
num_layers=2,
with_bias=False,
with_last_bn=False,
with_avg_pool=True),
head=dict(
type='LatentPredictHead',
predictor=dict(
type='NonLinearNeck',
in_channels=256,
hid_channels=4096,
out_channels=256,
num_layers=2,
with_bias=False,
with_last_bn=False,
with_avg_pool=False)))

update_interval = 1 # interval for accumulate gradient
# Amp optimizer
optimizer = dict(type='SGD', lr=0.4, weight_decay=0.0001, momentum=0.9)
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=optimizer,
accumulative_counts=update_interval,
)
warmup_epochs = 4
total_epochs = 800
# learning policy
param_scheduler = [
# warmup
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
end=warmup_epochs,
# Update the learning rate after every iters.
convert_to_iter_based=True),
# ConsineAnnealingLR/StepLR/..
dict(
type='CosineAnnealingLR',
eta_min=0.,
T_max=total_epochs,
by_epoch=True,
begin=warmup_epochs,
end=total_epochs)
]

# runtime settings
default_hooks = dict(checkpoint=dict(interval=100))
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=total_epochs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
_base_ = [
'../../_base_/models/byol.py',
'../../_base_/datasets/coco_orl.py',
'../../_base_/schedules/sgd_coslr-200e_in1k.py',
'../../_base_/default_runtime.py',
]
# model settings
model = dict(
neck=dict(
type='NonLinearNeck',
in_channels=2048,
hid_channels=4096,
out_channels=256,
num_layers=2,
with_bias=False,
with_last_bn=False,
with_avg_pool=True),
head=dict(
type='LatentPredictHead',
predictor=dict(
type='NonLinearNeck',
in_channels=256,
hid_channels=4096,
out_channels=256,
num_layers=2,
with_bias=False,
with_last_bn=False,
with_avg_pool=False)))

update_interval = 1 # interval for accumulate gradient
# Amp optimizer
optimizer = dict(type='SGD', lr=0.4, weight_decay=0.0001, momentum=0.9)
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=optimizer,
accumulative_counts=update_interval,
)
warmup_epochs = 4
total_epochs = 5
# learning policy
param_scheduler = [
# warmup
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
end=warmup_epochs,
# Update the learning rate after every iters.
convert_to_iter_based=True),
# ConsineAnnealingLR/StepLR/..
dict(
type='CosineAnnealingLR',
eta_min=0.,
T_max=total_epochs,
by_epoch=True,
begin=warmup_epochs,
end=total_epochs)
]

# "mmselfsup/configs/selfsup/orl/stage1/
# orl_resnet50_8xb64-coslr-800e_coco_extractor.py"
# runtime settings
default_hooks = dict(checkpoint=dict(interval=100))
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=total_epochs)
# load_from = './work_dirs/selfsup/orl/stage1/
# orl_resnet50_8xb64-coslr-800e_coco/epoch_100.pth'
# resume=True
custom_hooks = [
dict(
type='ExtractorHook',
keys=10,
extract_dataloader=dict(
batch_size=512,
num_workers=6,
persistent_workers=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type={{_base_.dataset_type}},
data_root={{_base_.data_root}},
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
pipeline={{_base_.train_pipeline}})),
normalize=True),
]
4 changes: 3 additions & 1 deletion mmselfsup/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .deepcluster_hook import DeepClusterHook
from .densecl_hook import DenseCLHook
from .extractor_hook import ExtractorHook
from .odc_hook import ODCHook
from .simsiam_hook import SimSiamHook
from .swav_hook import SwAVHook

__all__ = [
'DeepClusterHook', 'DenseCLHook', 'ODCHook', 'SimSiamHook', 'SwAVHook'
'DeepClusterHook', 'DenseCLHook', 'ODCHook', 'SimSiamHook', 'SwAVHook',
'ExtractorHook'
]
1 change: 1 addition & 0 deletions mmselfsup/engine/hooks/deepcluster_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def deepcluster(self, runner) -> None:
# step 1: get features
runner.model.eval()
features = self.extractor(runner.model.module)

runner.model.train()

# step 2: get labels
Expand Down
Loading