
unsupervised domain adaptation (UDA) method support #3834

Open
@woldier


Describe the feature

Motivation

I saw earlier that the DAFormer authors had been invited to contribute to this repository.
However, mmseg 1.x.x currently appears to have no support for UDA methods, so I would like mmseg to support them.
Because the DAFormer repository is based on a very old version of mmseg, I used its code as a reference and reimplemented it on mmseg 1.2.2.
Since I work on remote sensing, I first trained and tested some of the implemented UDA methods on remote sensing data (Potsdam, Vaihingen). They work fine under this version and achieve results similar to those of the original version.

However, when I trained on the autonomous-driving datasets (GTA, Cityscapes), the performance differed substantially from that of the original repository. I am therefore looking for help to finalize UDA support in the new version of mmseg.

Related resources

Below is part of my implementation, written against the mmengine conventions: the dataset wrapper (UDADataset), the data preprocessor (UDASegDataPreProcessor), and the UDA counterpart of EncoderDecoder (UDADecorator).

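import copy
import warnings

from mmengine.dataset import BaseDataset, force_full_init

from mmseg.registry import DATASETS


@DATASETS.register_module()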
class UDADataset:
    """ A wrapper of UDA source and target dataset.
    The length of the UDA dataset is ``len(source) * len(target)``.

    For a given ``idx``, sample ``idx // len(target)`` is taken from the
    source dataset and sample ``idx % len(target)`` from the target dataset.

    The returned data format is

    ···  dict(
    ···     inputs=inputs,
    ···     {{self._target_prefix}}inputs=target_inputs,
    ···     data_samples=dict(
    ···         gt_sem_seg=gt_sem_seg, # source label
    ···         metainfo=dict(
    ···                xxx=xxx # source meta
    ···             ),
    ···     ),
    ···     {{self._target_prefix}}data_samples=dict(
    ···         # ``gt_sem_seg`` is removed from the target samples
    ···         metainfo=dict(
    ···                xxx=xxx # target meta
    ···             ),
    ···     ),
    ···     tgt_key_prefix={{self._target_prefix}},
    ···  )

    Args:
        source (dict | BaseDataset): Config or instance of the source-domain dataset.
        target (dict | BaseDataset): Config or instance of the target-domain dataset.
        lazy_init (bool): If True, defer ``full_init`` until the dataset is first used.
            Defaults to False.
        target_prefix (str): Prefix of the target data keys. Defaults to ``'tgt'``;
            a trailing underscore is appended if missing, so the target data
            is stored under keys such as ``tgt_inputs``.
    """ # noqa

    def __init__(self, source, target, lazy_init=False, target_prefix='tgt', **kwargs):
        self._build_dataset(source=source, target=target)
        # Record meta information about the original dataset
        if not target_prefix.endswith('_'):
            target_prefix = target_prefix + '_'
        self._target_prefix = target_prefix
        self._metainfo = self.source.metainfo
        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def target_prefix(self):
        return self._target_prefix

    @property
    def source(self) -> BaseDataset:
        return self._source # noqa

    @property
    def target(self) -> BaseDataset:
        return self._target # noqa

    def _tgt_key(self, key: str) -> str:
        return self._target_prefix + key

    def _build_dataset(self, **dataset_dict):
        for k, dataset in dataset_dict.items():
            # Build each dataset from its config, or accept a ready instance;
            # the result is stored as `self._source` / `self._target`.
            if isinstance(dataset, dict):
                _dataset = DATASETS.build(dataset)
                setattr(self, '_' + k, _dataset)
            elif isinstance(dataset, BaseDataset):
                setattr(self, '_' + k, dataset)
            else:
                raise TypeError(
                    'elements in datasets sequence should be config or '
                    f'`BaseDataset` instance, but got {type(dataset)}')

    def full_init(self):
        if self._fully_initialized:
            return

        # Initialize the source and target dataset completely
        self.source.full_init()
        self.target.full_init()

        self._fully_initialized = True

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int):
        # Map the flat index to a (source index, target index) pair.
        t = len(self.target)
        ori_idx_s = idx // t
        ori_idx_t = idx % t
        return ori_idx_s, ori_idx_t

    #  Provides the same external interface as `self.dataset`.
    @force_full_init
    def get_data_info(self, idx):
        sample_idx_s, sample_idx_t = self._get_ori_dataset_idx(idx)
        return self.source.get_data_info(sample_idx_s)

    # Provides the same external interface as `__getitem__`.
    def __getitem__(self, idx):
        if not self._fully_initialized:
            warnings.warn('Please call `full_init` method manually to '
                          'accelerate the speed.')
            self.full_init()

        sample_idx_s, sample_idx_t = self._get_ori_dataset_idx(idx)
        out = self.source[sample_idx_s]
        out_t = self.target[sample_idx_t]
        tgt_inputs, tgt_data_samples = out_t['inputs'], out_t['data_samples']
        # Drop the target GT label so that it cannot be accessed by mistake
        # during UDA training.
        tgt_data_samples.pop('gt_sem_seg', None)
        out[self._tgt_key('inputs')] = tgt_inputs
        out[self._tgt_key('data_samples')] = tgt_data_samples
        # Store the prefix so that later stages can look up the target keys.
        out['tgt_key_prefix'] = self.target_prefix
        return out

    # Provides the same external interface as `self.dataset`.
    @force_full_init
    def __len__(self):
        len_wrapper = len(self.source) * len(self.target)
        return len_wrapper

    # Provides the same external interface as `self.dataset`.
    @property
    def metainfo(self):
        return copy.deepcopy(self._metainfo)
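
The index mapping used by `_get_ori_dataset_idx` can be checked in isolation. The following is only a minimal sketch with toy sizes; `paired_index` is a hypothetical helper introduced for illustration and is not part of the code above.

```python
from typing import Tuple


def paired_index(idx: int, num_source: int, num_target: int) -> Tuple[int, int]:
    """Map a flat index to (source_idx, target_idx), mirroring `_get_ori_dataset_idx`."""
    if not 0 <= idx < num_source * num_target:
        raise IndexError(f'index {idx} out of range')
    return idx // num_target, idx % num_target


# Toy sizes (hypothetical): 3 source samples and 2 target samples.
pairs = [paired_index(i, num_source=3, num_target=2) for i in range(3 * 2)]
print(pairs)
# [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
```

Every source sample is paired with every target sample exactly once, which is why `__len__` returns the product of the two dataset lengths.
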
# Data preprocessor (UDASegDataPreProcessor)
from typing import Dict, Any

import mmengine

from .data_preprocessor import SegDataPreProcessor
from mmseg.structures import SegDataSample
from mmseg.registry import MODELS
from mmseg.utils import stack_batch
import torch


@MODELS.register_module()
class UDASegDataPreProcessor(SegDataPreProcessor):
    """UDASegDataPreProcessor
    This processor provides support for data preprocessing in UDA training.

    UDA training needs target-domain data in addition to source-domain data.
    By convention, the source-domain keys are the same as those received by
    SegDataPreProcessor, while the target-domain keys carry a prefix, e.g.
    ``tgt_inputs`` and ``tgt_data_samples``. Apart from the prefix, both
    domains use the same key names.
    The data format is defined in detail in the wrapper class UDADataset.

    During UDA validation and testing, the dataset is the plain target-domain
    dataset rather than the UDADataset wrapper, so the behavior of
    SegDataPreProcessor is preserved in that case.
    """

    def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``SegDataPreProcessor``.

        Args:
            data (dict): data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            Dict: Data in the same format as the model input.
        """
        data = self.cast_data(data)  # type: ignore
        inputs = data['inputs']
        data_samples = data.get('data_samples', None)
        # TODO: whether normalize should be after stack_batch
        # The call below is commented out so that, during training, the source
        # data goes through the same shared operations as the other inputs.
        # To keep the non-training behavior unchanged, it is executed in the
        # `else` branch instead.
        # inputs = self._inputs_pre_process(inputs)

        if training:
            tgt_key_prefix = data.get('tgt_key_prefix', None)
            assert tgt_key_prefix is not None, \
                'During UDA training, `tgt_key_prefix` must be defined in the dataset.'
            # The prefix may arrive as a list with one entry per sample.
            if isinstance(tgt_key_prefix, list):
                tgt_key_prefix = tgt_key_prefix[0]
            assert data_samples is not None, \
                'During training, `data_samples` must be defined.'
            # Prefixes of any additional input groups; each entry may also be
            # a per-sample list.
            extra_fields = data.get('extra_fields', [])
            if mmengine.is_list_of(extra_fields, expected_type=(list, tuple)):
                extra_fields = [i[0] for i in extra_fields]
            extra_fields.append('')  # the empty prefix selects the source data
            if tgt_key_prefix not in extra_fields:
                extra_fields.append(tgt_key_prefix)

            out = dict(tgt_key_prefix=tgt_key_prefix)
            for prefix_name in extra_fields:
                a_inputs = data[f"{prefix_name}inputs"]
                a_data_samples = data[f'{prefix_name}data_samples']
                # inputs pre-process
                a_inputs = self._inputs_pre_process(a_inputs)
                # stack samples
                a_inputs, a_data_samples = stack_batch(
                    inputs=a_inputs, data_samples=a_data_samples,
                    size=self.size, size_divisor=self.size_divisor,
                    pad_val=self.pad_val, seg_pad_val=self.seg_pad_val # type: ignore
                )
                if self.batch_augments is not None:
                    a_inputs, a_data_samples = self.batch_augments(a_inputs, a_data_samples)
                out[f'{prefix_name}inputs'] = a_inputs
                out[f'{prefix_name}data_samples'] = a_data_samples
            return out
        else:
            inputs = self._inputs_pre_process(inputs)  # normalize here to keep training and test behavior consistent
            img_size = inputs[0].shape[1:]
            assert all(input_.shape[1:] == img_size for input_ in inputs), \
                'The image size in a batch should be the same.'
            # pad images when testing
            if self.test_cfg:
                inputs, padded_samples = stack_batch(
                    inputs=inputs,
                    size=self.test_cfg.get('size', None),
                    size_divisor=self.test_cfg.get('size_divisor', None),
                    pad_val=self.pad_val,
                    seg_pad_val=self.seg_pad_val)
                for data_sample, pad_info in zip(data_samples, padded_samples):
                    data_sample.set_metainfo({**pad_info})
            else:
                inputs = torch.stack(inputs, dim=0)

            return dict(inputs=inputs, data_samples=data_samples)

    def _inputs_pre_process(self, inputs):
        if self.channel_conversion and inputs[0].size(0) == 3:
            inputs = [_input[[2, 1, 0], ...] for _input in inputs]
        inputs = [_input.float() for _input in inputs]
        if self._enable_normalize:
            inputs = [(_input - self.mean) / self.std for _input in inputs]
        return inputs
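
The prefix convention in the training branch can be illustrated without mmseg. This is a rough sketch only: `split_by_prefix` and the string placeholders are hypothetical stand-ins for the real tensors and `SegDataSample` objects, and the batch layout is an assumption based on the UDADataset docstring.

```python
from typing import Any, Dict, List


def split_by_prefix(batch: Dict[str, Any],
                    prefixes: List[str]) -> Dict[str, Dict[str, Any]]:
    """Group `<prefix>inputs` / `<prefix>data_samples` pairs by their prefix."""
    groups = {}
    for prefix in prefixes:
        groups[prefix] = dict(
            inputs=batch[f'{prefix}inputs'],
            data_samples=batch[f'{prefix}data_samples'],
        )
    return groups


# Placeholder batch (hypothetical values); '' is the source prefix.
batch = {
    'inputs': ['src_img_0', 'src_img_1'],
    'data_samples': ['src_sample_0', 'src_sample_1'],
    'tgt_inputs': ['tgt_img_0', 'tgt_img_1'],
    'tgt_data_samples': ['tgt_sample_0', 'tgt_sample_1'],
    'tgt_key_prefix': ['tgt_', 'tgt_'],
}
tgt_prefix = batch['tgt_key_prefix'][0]
groups = split_by_prefix(batch, ['', tgt_prefix])
print(sorted(groups))
# ['', 'tgt_']
```

Each group would then go through the same normalize, stack and augment steps, which is what the loop over `extra_fields` does in `forward`.
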
# UDADecorator: base class for UDA methods
class UDADecorator(BaseSegmentor):
    """UDADecorator
    Base wrapper for UDA methods. It builds the wrapped segmentor from ``model``,
    forwards the segmentor interface (``extract_feat``, ``encode_decode``,
    ``loss``, ``predict``, ``_forward``) to it and counts training iterations.
    Subclasses implement the UDA training logic in ``_inner_train_step``.

    UDA training needs target-domain data in addition to source-domain data.
    By convention, the source-domain keys are the same as those received by
    SegDataPreProcessor, while the target-domain keys carry a prefix, e.g.
    ``tgt_inputs`` and ``tgt_data_samples``.
    The data format is defined in detail in the wrapper class UDADataset.

    During UDA validation and testing, the dataset is the plain target-domain
    dataset rather than the UDADataset wrapper, so the behavior of
    SegDataPreProcessor is preserved in that case.

    Parameters:
        model (dict): Config of the wrapped segmentor.
        data_preprocessor (dict, optional): The pre-process config of :class:`.BaseDataPreprocessor`.
        work_dir (str, optional): Working directory. Not used in the code shown here.
    """  # noqa: E501

    def __init__(self, model: ConfigType, data_preprocessor: OptConfigType = None,
                 work_dir: OPTStr = None, 
                 **cfg):
        super(BaseSegmentor, self).__init__(data_preprocessor=data_preprocessor)

        self.model = build_segmentor(deepcopy(model))
        self.align_corners = self.model.align_corners
        self.train_cfg = model['train_cfg']
        self.test_cfg = model['test_cfg']
        self.num_classes = model['decode_head']['num_classes']
        # Track the current training iteration so that subclasses can act at
        # specific steps, e.g. visualize outputs at a given iteration or, for
        # teacher-based methods, derive the EMA update weight from it.
        self.register_buffer('local_iter', torch.tensor(0, dtype=torch.long))

    @property
    def iter(self):
        return self.local_iter

    def get_model(self) -> EncoderDecoder:
        return get_module(self.model)

    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
        return self.get_model().extract_feat(inputs)

    def encode_decode(self, inputs: Tensor, batch_data_samples: List[dict]) -> Tensor:
        return self.get_model().encode_decode(inputs, batch_data_samples)

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        return self.get_model().loss(inputs, data_samples)

    def predict(self, inputs: Tensor, data_samples: OptSampleList = None) -> SampleList:
        return self.get_model().predict(inputs, data_samples)

    def _forward(self, inputs: Tensor, data_samples: OptSampleList = None) -> Tensor:
        return self.get_model()._forward(inputs, data_samples)

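    def train_step(self, data: Union[dict, tuple, list], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
        """Common entry point for all UDA methods.

        Preprocesses the batch, delegates the training logic to
        ``_inner_train_step`` and increments the iteration counter.

        Args:
            data (dict | tuple | list): A batch sampled from the dataloader.
            optim_wrapper (OptimWrapper): Wrapper used to update the parameters.

        Returns:
            Dict[str, torch.Tensor]: Variables to be logged.
        """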
        data = self.data_preprocessor(data, True)  
        out = self._inner_train_step(data, optim_wrapper)
        self.local_iter += 1  # local iter acc
        return out

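    def _inner_train_step(self, data: Union[dict, tuple, list], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
        """Training logic of a concrete UDA method; implemented by subclasses.

        Args:
            data (dict): The batch, already preprocessed by ``train_step``.
            optim_wrapper (OptimWrapper): Wrapper used to update the parameters.

        Returns:
            Dict[str, torch.Tensor]: Variables to be logged.

        Examples:

        >>> # Source-only supervised step; `data` was already preprocessed
        >>> # by `train_step`, so it is not preprocessed again here.
        >>> with optim_wrapper.optim_context(self):
        >>>     losses = self.loss(data['inputs'], data['data_samples'])
        >>> parsed_losses, log_vars = self.parse_losses(losses)  # type: ignore
        >>> optim_wrapper.update_params(parsed_losses)
        >>> return log_vars
        """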
        raise NotImplementedError
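
The comment on `local_iter` mentions deriving the EMA update weight from the iteration count. As a hedged illustration only, the sketch below shows one common warm-up schedule for mean-teacher style methods, applied to plain Python floats instead of model parameters. The names `ema_momentum` and `ema_update` and the default `max_momentum=0.999` are assumptions for this example and do not come from the code above.

```python
from typing import List


def ema_momentum(cur_iter: int, max_momentum: float = 0.999) -> float:
    """Warm-up schedule: small momentum in early iterations, capped later."""
    return min(1.0 - 1.0 / (cur_iter + 1), max_momentum)


def ema_update(teacher: List[float], student: List[float],
               momentum: float) -> List[float]:
    """teacher <- momentum * teacher + (1 - momentum) * student."""
    return [momentum * t + (1.0 - momentum) * s
            for t, s in zip(teacher, student)]


# Toy "weights" (hypothetical): at iteration 0 the momentum is 0, so the
# teacher starts as an exact copy of the student.
teacher = [0.0, 0.0]
student = [1.0, 2.0]
for cur_iter in range(3):
    teacher = ema_update(teacher, student, ema_momentum(cur_iter))
print(teacher)
```

In a subclass, `cur_iter` would be read from `self.local_iter` inside `_inner_train_step`, and the update would run over the teacher and student parameters instead of lists.
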

    

