Description
Motivation
I saw earlier that the maintainers were inviting the DAFormer authors to contribute their work to this repository. However, as of now, mmseg 1.x.x seems to have no support for UDA methods, so I would like mmseg to support them. Since the version of mmseg used by the DAFormer repository is too old, I referenced that repository's code and re-implemented it under mmseg 1.2.2.
Since my research direction is remote sensing, I first trained and tested some of the implemented UDA methods on remote sensing data (Potsdam, Vaihingen), and found that under this version they work fine and achieve results similar to the original implementation. However, when I trained on the autonomous-driving datasets (GTA5, Cityscapes), the performance differed considerably from the original repository. Therefore, I am looking for some help to finalize UDA support in the new version of mmseg.
Related resources
The following is part of my implementation, written against the mmengine specification: the UDADataset wrapper, the UDA data preprocessor (UDASegDataPreProcessor), and the UDA segmentor base class (UDADecorator).
import copy
import warnings

from mmengine.dataset import BaseDataset, force_full_init

from mmseg.registry import DATASETS


@DATASETS.register_module()
class UDADataset:
""" A wrapper of UDA source and target dataset.
The length of uda dataset will be source dataset 'times' target dataset.
for idx, which means we need data ``idx%len(target)`` form source
and ``idx//len(target)`` from target
返回的数据格式为
··· dict(
··· inputs=inputs,
··· taget_inputs=taget_inputs,
··· data_samples=dict(
··· gt_sem_seg=gt_sem_seg, # source label
··· metainfo=dict(
··· xxx=xxx # source meta
··· ),
··· ),
··· {{self._target_prefix}}data_samples=dict(
··· gt_sem_seg=gt_sem_seg, # target label
··· metainfo=dict(
··· xxx=xxx # target meta
··· ),
··· ),
··· )
Args:
source: source config
target: target config
lazy_init:
target_prefix: Used to set the prefix of the target data key.
The default is tgt, so the target data loaded at this time will have a key that starts with tgt, such as tgt_input.
""" # noqa
def __init__(self, source, target, lazy_init=False, target_prefix='tgt', **kwargs):
self._build_dataset(source=source, target=target)
# Record meta information about the original dataset
if not target_prefix.endswith('_'):
target_prefix = target_prefix + '_'
self._target_prefix = target_prefix
self._metainfo = self.source.metainfo
self._fully_initialized = False
if not lazy_init:
self.full_init()
@property
def target_prefix(self):
return self._target_prefix
@property
def source(self) -> BaseDataset:
return self._source # noqa
@property
def target(self) -> BaseDataset:
return self._target # noqa
def _tgt_key(self, key: str) -> str:
return self._target_prefix + key
def _build_dataset(self, **dataset_dict):
for k, dataset in dataset_dict.items():
# Build the source dataset (self.dataset)
if isinstance(dataset, dict):
_dataset = DATASETS.build(dataset)
setattr(self, '_' + k, _dataset)
elif isinstance(dataset, BaseDataset):
setattr(self, '_' + k, dataset)
else:
                raise TypeError(
                    '`source` and `target` should be a config dict or a '
                    f'`BaseDataset` instance, but got {type(dataset)}')
def full_init(self):
if self._fully_initialized:
return
# Initialize the source and target dataset completely
self.source.full_init()
self.target.full_init()
self._fully_initialized = True
@force_full_init
def _get_ori_dataset_idx(self, idx: int):
s, t = len(self.source), len(self.target)
ori_idx_s = idx // t
ori_idx_t = idx % t
ori_idx = (ori_idx_s, ori_idx_t)
return ori_idx
    # Keep the same external interface as a regular dataset; the meta
    # information comes from the source dataset.
    @force_full_init
    def get_data_info(self, idx):
        sample_idx_s, _ = self._get_ori_dataset_idx(idx)
        return self.source.get_data_info(sample_idx_s)
    # Keep the same external interface as a regular dataset's ``__getitem__``.
def __getitem__(self, idx):
if not self._fully_initialized:
warnings.warn('Please call `full_init` method manually to '
'accelerate the speed.')
self.full_init()
sample_idx_s, sample_idx_t = self._get_ori_dataset_idx(idx)
out = self.source[sample_idx_s]
out_t = self.target[sample_idx_t]
tgt_inputs, tgt_data_samples = out_t["inputs"], out_t['data_samples']
        # Remove the GT label of the target data so that it cannot be
        # accessed by mistake during UDA training.
        tgt_data_samples.pop('gt_sem_seg', None)
        out[self._tgt_key('inputs')] = tgt_inputs
        out[self._tgt_key('data_samples')] = tgt_data_samples
        # `tgt_key_prefix` makes it easy to locate the target keys during
        # subsequent processing (e.g. in the data preprocessor).
        out['tgt_key_prefix'] = self.target_prefix
return out
    # Keep the same external interface as a regular dataset.
@force_full_init
def __len__(self):
len_wrapper = len(self.source) * len(self.target)
return len_wrapper
    # Keep the same external interface as a regular dataset.
@property
def metainfo(self):
return copy.deepcopy(self._metainfo)
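For reference, the following is a minimal sketch of how I wire the wrapper into a training dataloader. The dataset types, data roots and pipelines are placeholders (not the exact configs from my experiments); in practice they would be the source and target configs of the chosen benchmark, e.g. GTA5 -> Cityscapes or Potsdam -> Vaihingen.

# Minimal config sketch (placeholder values): a train_dataloader that pairs a
# source dataset and a target dataset through the UDADataset wrapper.
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type='UDADataset',
        target_prefix='tgt',
        source=dict(
            type='CityscapesDataset',   # placeholder source dataset type
            data_root='data/source',
            data_prefix=dict(img_path='img_dir', seg_map_path='ann_dir'),
            pipeline=[...]),            # source pipeline (loads image + annotation)
        target=dict(
            type='CityscapesDataset',   # placeholder target dataset type
            data_root='data/target',
            data_prefix=dict(img_path='img_dir', seg_map_path='ann_dir'),
            pipeline=[...])))           # target pipeline (target GT is dropped by the wrapper)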
from typing import Any, Dict

import mmengine
import torch

from mmseg.registry import MODELS
from mmseg.utils import stack_batch

from .data_preprocessor import SegDataPreProcessor
@MODELS.register_module()
class UDASegDataPreProcessor(SegDataPreProcessor):
"""UDASegDataPreProcessor
This processor provides support for data preprocessing in UDA training.
In UDA training, the target domain data is needed in addition to the source domain data.
We agree that the key of the source domain data is the same as the key received by SegDataPreProcessor.
And for target domain data, its key has prefix such as target_inputs target_data_samples.
In this case, the naming of the target domain data key is in the same format as the naming of the source domain data, the only difference being the presence of a prefix.
The definition of the data format is detailed in the Wrapper class UDADataset
During the UDA testing process (val and test), the initialized dataset is the target domain dataset, not the Wrapper class UDADataset.
Therefore, the behavior of the SegDataPreProcessor is preserved in this case.
"""
def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
"""Perform normalization、padding and bgr2rgb conversion based on
``SegDataPreProcessor``.
Args:
data (dict): data sampled from dataloader.
training (bool): Whether to enable training time augmentation.
Returns:
Dict: Data in the same format as the model input.
"""
data = self.cast_data(data) # type: ignore
inputs = data['inputs']
data_samples = data.get('data_samples', None)
        # TODO: whether normalize should be after stack_batch
        # `inputs = self._inputs_pre_process(inputs)` is not called here:
        # during training the source data is handled together with the target
        # data by the shared per-prefix loop below, and to keep the behavior
        # consistent in the non-training case the same call is executed in
        # the `else` branch.
        if training:
            tgt_key_prefix = data.get('tgt_key_prefix', None)
            assert tgt_key_prefix is not None, \
                'During UDA training, `tgt_key_prefix` must be defined by the dataset.'
            tgt_key_prefix = tgt_key_prefix[0] if isinstance(
                tgt_key_prefix, list) else tgt_key_prefix
            assert data_samples is not None, \
                'During training, `data_samples` must be defined.'
            # `extra_fields` may carry additional key prefixes to process;
            # '' (the source data) and the target prefix are always included.
            extra_fields = data.get('extra_fields', [])
            if mmengine.is_list_of(extra_fields, expected_type=(list, tuple)):
                extra_fields = [i[0] for i in extra_fields]
            extra_fields.append('')
            if tgt_key_prefix not in extra_fields:
                extra_fields.append(tgt_key_prefix)
            out = dict(tgt_key_prefix=tgt_key_prefix)
            for prefix_name in extra_fields:
                a_inputs = data[f'{prefix_name}inputs']
                a_data_samples = data[f'{prefix_name}data_samples']
                # inputs pre-process (bgr2rgb, cast to float, normalize)
                a_inputs = self._inputs_pre_process(a_inputs)
                # stack samples into a batch and pad if necessary
                a_inputs, a_data_samples = stack_batch(
                    inputs=a_inputs,
                    data_samples=a_data_samples,
                    size=self.size,
                    size_divisor=self.size_divisor,
                    pad_val=self.pad_val,
                    seg_pad_val=self.seg_pad_val)  # type: ignore
                if self.batch_augments is not None:
                    a_inputs, a_data_samples = self.batch_augments(
                        a_inputs, a_data_samples)
                out[f'{prefix_name}inputs'] = a_inputs
                out[f'{prefix_name}data_samples'] = a_data_samples
            return out
        else:
            # Normalize here as well to keep training and test behavior
            # consistent.
            inputs = self._inputs_pre_process(inputs)
img_size = inputs[0].shape[1:]
assert all(input_.shape[1:] == img_size for input_ in inputs), \
'The image size in a batch should be the same.'
# pad images when testing
if self.test_cfg:
inputs, padded_samples = stack_batch(
inputs=inputs,
size=self.test_cfg.get('size', None),
size_divisor=self.test_cfg.get('size_divisor', None),
pad_val=self.pad_val,
seg_pad_val=self.seg_pad_val)
for data_sample, pad_info in zip(data_samples, padded_samples):
data_sample.set_metainfo({**pad_info})
else:
inputs = torch.stack(inputs, dim=0)
return dict(inputs=inputs, data_samples=data_samples)
    def _inputs_pre_process(self, inputs):
        """Per-image channel conversion, cast to float and normalization."""
if self.channel_conversion and inputs[0].size(0) == 3:
inputs = [_input[[2, 1, 0], ...] for _input in inputs]
inputs = [_input.float() for _input in inputs]
if self._enable_normalize:
inputs = [(_input - self.mean) / self.std for _input in inputs]
return inputs
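To show where this preprocessor plugs in, below is a minimal model-config sketch. The normalization values and crop size are placeholders, and 'SomeUDAMethod' stands for whichever concrete UDA segmentor (a subclass of the decorator below) is actually used; none of these values are taken from my experiments.

# Minimal config sketch (placeholder values): the UDA segmentor owns a
# UDASegDataPreProcessor and wraps a standard EncoderDecoder config.
model = dict(
    type='SomeUDAMethod',                 # hypothetical UDA method, see UDADecorator below
    data_preprocessor=dict(
        type='UDASegDataPreProcessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_val=0,
        seg_pad_val=255,
        size=(512, 512)),
    model=dict(
        type='EncoderDecoder',            # the wrapped segmentor config
        backbone=...,                     # backbone config goes here
        decode_head=...,                  # decode head config (defines num_classes)
        train_cfg=dict(),
        test_cfg=dict(mode='whole')))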
from copy import deepcopy
from typing import Dict, List, Optional, Union

import torch
from mmengine.model import is_model_wrapper
from mmengine.optim import OptimWrapper
from torch import Tensor

from mmseg.models import BaseSegmentor, EncoderDecoder, build_segmentor
from mmseg.utils import ConfigType, OptConfigType, OptSampleList, SampleList


def get_module(module):
    """Return the underlying model if `module` is wrapped (e.g. by DDP)."""
    return module.module if is_model_wrapper(module) else module


class UDADecorator(BaseSegmentor):
    """Base class that wraps a standard segmentor for UDA training.

    The decorator exposes the same interfaces as the wrapped segmentor
    (``extract_feat``, ``encode_decode``, ``loss``, ``predict``) and keeps
    track of the current training iteration, which many UDA methods need,
    e.g. for EMA teacher updates or periodic visualization. Concrete UDA
    methods should inherit from this class and implement
    ``_inner_train_step``.

    Args:
        model (dict): Config of the segmentor to be wrapped.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.
    """  # noqa: E501
    def __init__(self,
                 model: ConfigType,
                 data_preprocessor: OptConfigType = None,
                 work_dir: Optional[str] = None,
                 **cfg):
super(BaseSegmentor, self).__init__(data_preprocessor=data_preprocessor)
self.model = build_segmentor(deepcopy(model))
self.align_corners = self.model.align_corners
self.train_cfg = model['train_cfg']
self.test_cfg = model['test_cfg']
self.num_classes = model['decode_head']['num_classes']
        # Record the current number of training iterations so that a UDA
        # method can act at specific steps, e.g. visualize outputs every N
        # iterations, or derive the EMA weight for a teacher model from the
        # iteration count.
self.register_buffer('local_iter', torch.tensor(0, dtype=torch.long))
@property
def iter(self):
return self.local_iter
def get_model(self) -> EncoderDecoder:
return get_module(self.model)
def extract_feat(self, inputs: Tensor) -> List[Tensor]:
return self.get_model().extract_feat(inputs)
def encode_decode(self, inputs: Tensor, batch_data_samples: List[dict]) -> Tensor:
return self.get_model().encode_decode(inputs, batch_data_samples)
def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
return self.get_model().loss(inputs, data_samples)
def predict(self, inputs: Tensor, data_samples: OptSampleList = None) -> SampleList:
return self.get_model().predict(inputs, data_samples)
def _forward(self, inputs: Tensor, data_samples: OptSampleList = None) -> Tensor:
return self.get_model()._forward(inputs, data_samples)
def train_step(self, data: Union[dict, tuple, list], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""
Abstracts all UDA methods of the uda model. The current number of iterations is recorded.
The training logic is done by _inner_train_step.
Args:
data:
optim_wrapper:
Returns:
"""
data = self.data_preprocessor(data, True)
out = self._inner_train_step(data, optim_wrapper)
self.local_iter += 1 # local iter acc
return out
def _inner_train_step(self, data: Union[dict, tuple, list], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""
Args:
data:
optim_wrapper:
Returns:
Examples:
>>> with optim_wrapper.optim_context(self.model):
>>> data = self.model.data_preprocessor(data, True)
>>> losses = self.model._run_forward(data, mode='loss') # type: ignore
>>> parsed_losses, log_vars = self.parse_losses(losses) # type: ignore
>>> optim_wrapper.update_params(parsed_losses)
>>> return log_vars
"""
raise NotImplementedError
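To make the intended usage concrete, here is a minimal sketch of a method built on the decorator. This is my own illustrative assumption rather than code from my implementation: a 'source-only' baseline whose _inner_train_step only runs a supervised step on the source batch; a real UDA method would add target-domain losses on top of it.

from mmseg.registry import MODELS


# Hypothetical example: supervised training on the source batch only. Real
# UDA methods would additionally use the `tgt_*` entries in `data`, e.g. for
# pseudo-labeling or adversarial alignment.
@MODELS.register_module()
class SourceOnly(UDADecorator):

    def _inner_train_step(self, data, optim_wrapper):
        with optim_wrapper.optim_context(self.model):
            # Supervised loss on the source domain data.
            losses = self.loss(data['inputs'], data['data_samples'])
        parsed_losses, log_vars = self.parse_losses(losses)
        optim_wrapper.update_params(parsed_losses)
        # Target data is available as data[f"{data['tgt_key_prefix']}inputs"]
        # and data[f"{data['tgt_key_prefix']}data_samples"] for UDA losses.
        return log_vars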