# Tutorial 5: Training Tricks

MMSegmentation supports the following training tricks out of the box.

## Different Learning Rate (LR) for Backbone and Heads

In semantic segmentation, some methods set the LR of the heads larger than that of the backbone to achieve better performance or faster convergence.

In MMSegmentation, you may add the following lines to the config to make the LR of the heads 10 times that of the backbone.
```python
optimizer=dict(
    paramwise_cfg=dict(
        custom_keys={
            'head': dict(lr_mult=10.)}))
```
With this modification, the LR of any parameter group whose name contains `'head'` will be multiplied by 10.
You may refer to [MMCV doc](https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.DefaultOptimizerConstructor) for further details.
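
For context, this snippet is usually merged into a fuller optimizer config. Below is a minimal sketch (the SGD hyperparameters are illustrative assumptions, not taken from any particular config) showing how `lr_mult` interacts with the base LR:

```python
# Illustrative values: with a base LR of 0.01, backbone parameters are updated
# with LR 0.01, while any parameter whose name contains 'head' uses 0.01 * 10 = 0.1.
optimizer=dict(
    type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005,
    paramwise_cfg=dict(
        custom_keys={
            'head': dict(lr_mult=10.)}))
```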

## Online Hard Example Mining (OHEM)
We implement a pixel sampler [here](https://github.com/open-mmlab/mmsegmentation/tree/master/mmseg/core/seg/sampler) for sampling pixels during training.
Here is an example config for training PSPNet with OHEM enabled.
```python
_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model=dict(
    decode_head=dict(
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)))
```
In this way, only pixels with a confidence score under 0.7 are used for training, and at least 100000 pixels are kept during training. If `thresh` is not specified, the pixels with the top `min_kept` losses will be selected instead.
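
Based on the behavior described above, if you prefer to always train on a fixed number of hardest pixels regardless of confidence, you can omit `thresh`. This is a sketch, not an official config:

```python
_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model=dict(
    decode_head=dict(
        # Without thresh, the sampler keeps the min_kept pixels with the highest loss.
        sampler=dict(type='OHEMPixelSampler', min_kept=100000)))
```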

## Class Balanced Loss
For datasets with an imbalanced class distribution, you may change the loss weight of each class.
Here is an example for the Cityscapes dataset.
```python
_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model=dict(
    decode_head=dict(
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
            # DeepLab used this class weight for cityscapes
            class_weight=[0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
                        1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
                        1.0865, 1.0955, 1.0865, 1.1529, 1.0507])))
```

`class_weight` will be passed into `CrossEntropyLoss` as the `weight` argument. Please refer to [PyTorch Doc](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) for details.
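
To see what the `weight` argument does, here is a minimal standalone PyTorch sketch with toy tensors (independent of any MMSegmentation config):

```python
import torch
import torch.nn as nn

# Toy example with 3 classes; per-class weights analogous to class_weight above.
class_weight = torch.tensor([0.8, 1.0, 1.2])
criterion = nn.CrossEntropyLoss(weight=class_weight)

logits = torch.randn(4, 3)           # (batch, num_classes)
labels = torch.tensor([0, 2, 1, 2])  # ground-truth class indices
# Each sample's loss is multiplied by class_weight[label] before reduction.
loss = criterion(logits, labels)
print(loss)
```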