Commit ebf3c084 authored by 谢昕辰, committed by GitHub
[Tools] Add vit/swin/mit convert weight scripts (#783)

* init scripts

* update markdown

* update markdown

* add docs

* delete mit converter and use torch load function

* rename segformer readme

* update doc

* modify doc

* Update Chinese documentation

* Update useful_tools.md

* Update useful_tools.md

* modify doc

* update segformer.yml
parent 2fe0bddf
@@ -29,15 +29,15 @@
Evaluation with AlignedResize:
| Method    | Backbone | Crop Size | Lr schd |  mIoU | mIoU(ms+flip) |
| --------- | -------- | --------- | ------: | ----: | ------------- |
| Segformer | MIT-B0   | 512x512   |  160000 |  38.1 | 38.57         |
| Segformer | MIT-B1   | 512x512   |  160000 | 41.64 | 42.76         |
| Segformer | MIT-B2   | 512x512   |  160000 | 46.53 | 47.49         |
| Segformer | MIT-B3   | 512x512   |  160000 | 48.46 | 49.14         |
| Segformer | MIT-B4   | 512x512   |  160000 | 49.34 | 50.29         |
| Segformer | MIT-B5   | 512x512   |  160000 | 50.08 | 50.72         |
| Segformer | MIT-B5   | 640x640   |  160000 | 50.58 | 50.8          |
We replace `AlignedResize` in the original implementation with `Resize + ResizeToMultiple`. If you want to test with
`AlignedResize`, you can change the dataset pipeline like this:
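The snippet itself is elided from this diff view; the following is only a sketch of such a test pipeline, built from the `Resize + ResizeToMultiple` combination named above (the `img_scale` and `img_norm_cfg` values are assumptions, adapt them to your config):

```python
# A sketch only -- the actual snippet is elided from this diff. It mimics
# AlignedResize with the Resize + ResizeToMultiple combination named above;
# img_scale and img_norm_cfg values are assumptions.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            # Resize the image to a multiple of 32 so every feature map
            # keeps aligned spatial dimensions, as AlignedResize did.
            dict(type='ResizeToMultiple', size_divisor=32),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
```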
......
Collections:
- Metadata:
    Training Data:
    - ADE20k
  Name: segformer
Models:
- Config: configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py
  In Collection: segformer
  Metadata:
    backbone: MIT-B0
    crop size: (512,512)
    inference time (ms/im):
    - backend: PyTorch
      batch size: 1
      hardware: V100
      mode: FP32
      resolution: (512,512)
      value: 19.49
    lr schd: 160000
    memory (GB): 2.1
  Name: segformer_mit-b0_512x512_160k_ade20k
  Results:
  - Dataset: ADE20k
    Metrics:
      mIoU: 37.41
      mIoU(ms+flip): 38.34
    Task: Semantic Segmentation
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b0_512x512_160k_ade20k/segformer_mit-b0_512x512_160k_ade20k_20210726_101530-8ffa8fda.pth
- Config: configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py
  In Collection: segformer
  Metadata:
    backbone: MIT-B1
    crop size: (512,512)
    inference time (ms/im):
    - backend: PyTorch
      batch size: 1
      hardware: V100
      mode: FP32
      resolution: (512,512)
      value: 20.98
    lr schd: 160000
    memory (GB): 2.6
  Name: segformer_mit-b1_512x512_160k_ade20k
  Results:
  - Dataset: ADE20k
    Metrics:
      mIoU: 40.97
      mIoU(ms+flip): 42.54
    Task: Semantic Segmentation
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b1_512x512_160k_ade20k/segformer_mit-b1_512x512_160k_ade20k_20210726_112106-d70e859d.pth
- Config: configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py
  In Collection: segformer
  Metadata:
    backbone: MIT-B2
    crop size: (512,512)
    inference time (ms/im):
    - backend: PyTorch
      batch size: 1
      hardware: V100
      mode: FP32
      resolution: (512,512)
      value: 32.38
    lr schd: 160000
    memory (GB): 3.6
  Name: segformer_mit-b2_512x512_160k_ade20k
  Results:
  - Dataset: ADE20k
    Metrics:
      mIoU: 45.58
      mIoU(ms+flip): 47.03
    Task: Semantic Segmentation
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20210726_112103-cbd414ac.pth
- Config: configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py
  In Collection: segformer
  Metadata:
    backbone: MIT-B3
    crop size: (512,512)
    inference time (ms/im):
    - backend: PyTorch
      batch size: 1
      hardware: V100
      mode: FP32
      resolution: (512,512)
      value: 45.23
    lr schd: 160000
    memory (GB): 4.8
  Name: segformer_mit-b3_512x512_160k_ade20k
  Results:
  - Dataset: ADE20k
    Metrics:
      mIoU: 47.82
      mIoU(ms+flip): 48.81
    Task: Semantic Segmentation
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b3_512x512_160k_ade20k/segformer_mit-b3_512x512_160k_ade20k_20210726_081410-962b98d2.pth
- Config: configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py
  In Collection: segformer
  Metadata:
    backbone: MIT-B4
    crop size: (512,512)
    inference time (ms/im):
    - backend: PyTorch
      batch size: 1
      hardware: V100
      mode: FP32
      resolution: (512,512)
      value: 64.72
    lr schd: 160000
    memory (GB): 6.1
  Name: segformer_mit-b4_512x512_160k_ade20k
  Results:
  - Dataset: ADE20k
    Metrics:
      mIoU: 48.46
      mIoU(ms+flip): 49.76
    Task: Semantic Segmentation
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_512x512_160k_ade20k/segformer_mit-b4_512x512_160k_ade20k_20210728_183055-7f509d7d.pth
- Config: configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py
  In Collection: segformer
  Metadata:
    backbone: MIT-B5
    crop size: (512,512)
    inference time (ms/im):
    - backend: PyTorch
      batch size: 1
      hardware: V100
      mode: FP32
      resolution: (512,512)
      value: 84.1
    lr schd: 160000
    memory (GB): 7.2
  Name: segformer_mit-b5_512x512_160k_ade20k
  Results:
  - Dataset: ADE20k
    Metrics:
      mIoU: 49.13
      mIoU(ms+flip): 50.22
    Task: Semantic Segmentation
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_512x512_160k_ade20k/segformer_mit-b5_512x512_160k_ade20k_20210726_145235-94cedf59.pth
- Config: configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py
  In Collection: segformer
  Metadata:
    backbone: MIT-B5
    crop size: (640,640)
    inference time (ms/im):
    - backend: PyTorch
      batch size: 1
      hardware: V100
      mode: FP32
      resolution: (640,640)
      value: 88.5
    lr schd: 160000
    memory (GB): 11.5
  Name: segformer_mit-b5_640x640_160k_ade20k
  Results:
  - Dataset: ADE20k
    Metrics:
      mIoU: 49.62
      mIoU(ms+flip): 50.36
    Task: Semantic Segmentation
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20210801_121243-41d2845b.pth
@@ -255,6 +255,36 @@ Examples:
python tools/analyze_logs.py log.json --keys loss --legend loss
```
### Model conversion
`tools/model_converters/` provides several scripts to convert pretrained models released by other repos to the MMSegmentation style.
#### ViT, Swin and MiT Transformer Models
- ViT
`tools/model_converters/vit2mmseg.py` converts keys in timm pretrained ViT models to the MMSegmentation style (see the export sketch after this list).
```shell
python tools/model_converters/vit2mmseg.py ${SRC} ${DST}
```
- Swin
`tools/model_converters/swin2mmseg.py` converts keys in official pretrained Swin models to the MMSegmentation style.
```shell
python tools/model_converters/swin2mmseg.py ${SRC} ${DST}
```
- SegFormer
`tools/model_converters/mit2mmseg.py` converts keys in official pretrained MiT models to the MMSegmentation style.
```shell
python tools/model_converters/mit2mmseg.py ${SRC} ${DST}
```
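If you need a source checkpoint to try the ViT conversion, a minimal sketch (assuming `timm` is installed; the model name is only an illustration) is to export a timm state dict first:

```python
# Hypothetical workflow: dump a timm ViT state dict that vit2mmseg.py can read.
import timm
import torch

model = timm.create_model('vit_base_patch16_224', pretrained=True)
torch.save(model.state_dict(), 'vit_base_patch16_224_timm.pth')
# Then convert it:
#   python tools/model_converters/vit2mmseg.py \
#       vit_base_patch16_224_timm.pth vit_base_patch16_224_mmseg.pth
```

After any of the three conversions, a quick sanity check (the path here is hypothetical) is to confirm the output is a flat, MMSegmentation-style state dict:

```python
import torch

ckpt = torch.load('vit_base_patch16_224_mmseg.pth', map_location='cpu')
# The converters save a plain state dict: no 'state_dict'/'model' wrapper
# and no timm-style 'blocks.' keys left over.
assert 'state_dict' not in ckpt and 'model' not in ckpt
assert not any(k.startswith('blocks.') for k in ckpt)
print(sorted(ckpt)[:5])
```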
## Model Serving
To serve an `MMSegmentation` model with [`TorchServe`](https://pytorch.org/serve/), you can follow the steps below:
......
@@ -259,6 +259,36 @@ python tools/analyze_logs.py xxx.log.json [--keys ${KEYS}] [--legend ${LEGEND}]
python tools/analyze_logs.py log.json --keys loss --legend loss
```
### Convert weights from other repositories
`tools/model_converters/` provides several weight conversion scripts that convert the keys of pretrained weights from other repositories into keys that match MMSegmentation.
#### ViT, Swin and MiT Transformer models
- ViT
`tools/model_converters/vit2mmseg.py` converts timm pretrained models to MMSegmentation.
```shell
python tools/model_converters/vit2mmseg.py ${SRC} ${DST}
```
- Swin
`tools/model_converters/swin2mmseg.py` converts official pretrained models to MMSegmentation.
```shell
python tools/model_converters/swin2mmseg.py ${SRC} ${DST}
```
- SegFormer
`tools/model_converters/mit2mmseg.py` converts official pretrained models to MMSegmentation.
```shell
python tools/model_converters/mit2mmseg.py ${SRC} ${DST}
```
## Model Serving
To serve an `MMSegmentation` model with [`TorchServe`](https://pytorch.org/serve/), you can follow these steps:
......
@@ -23,6 +23,7 @@ Import:
- configs/psanet/psanet.yml
- configs/pspnet/pspnet.yml
- configs/resnest/resnest.yml
- configs/segformer/segformer.yml
- configs/sem_fpn/sem_fpn.yml
- configs/setr/setr.yml
- configs/swin/swin.yml
......
@@ -5,7 +5,7 @@ from collections import OrderedDict

 import torch


-def mit_convert(ckpt):
+def convert_mit(ckpt):
     new_ckpt = OrderedDict()
     # Process the concat between q linear weights and kv linear weights
     for k, v in ckpt.items():
@@ -73,5 +73,5 @@ if __name__ == '__main__':
     ckpt = torch.load(src_path, map_location='cpu')
-    ckpt = mit_convert(ckpt)
+    ckpt = convert_mit(ckpt)
     torch.save(ckpt, dst_path)
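The `# Process the concat between q linear weights and kv linear weights` comment above refers to SegFormer's attention layout; here is a toy illustration of the idea (the tiny embed dim and tensor names are assumptions, not the converter's actual variables):

```python
import torch

D = 4  # toy embedding dim
q_weight = torch.randn(D, D)       # official SegFormer 'attn.q' linear
kv_weight = torch.randn(2 * D, D)  # official SegFormer 'attn.kv' linear
# MMSegmentation's attention is built on torch.nn.MultiheadAttention, whose
# combined projection is one (3*D, D) matrix in q, k, v order, so stacking
# q on top of kv produces the expected layout.
in_proj_weight = torch.cat([q_weight, kv_weight], dim=0)
print(in_proj_weight.shape)  # torch.Size([12, 4])
```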
import argparse
from collections import OrderedDict

import torch


def convert_swin(ckpt):
    new_ckpt = OrderedDict()

    def correct_unfold_reduction_order(x):
        # The official Swin PatchMerging concatenates the four sub-sampled
        # feature maps patch group by patch group, while MMSegmentation
        # merges patches with nn.Unfold, which enumerates the 2x2 window in
        # row-major order and groups values channel-first. Reorder the
        # columns of the reduction weight to match.
        out_channel, in_channel = x.shape
        x = x.reshape(out_channel, 4, in_channel // 4)
        x = x[:, [0, 2, 1, 3], :].transpose(1,
                                            2).reshape(out_channel, in_channel)
        return x

    def correct_unfold_norm_order(x):
        # Apply the same reordering to the 1-D norm weight/bias that runs
        # before the reduction layer.
        in_channel = x.shape[0]
        x = x.reshape(4, in_channel // 4)
        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
        return x

    for k, v in ckpt.items():
        if k.startswith('head'):
            # The classification head is not used in segmentation.
            continue
        elif k.startswith('layers'):
            new_v = v
            if 'attn.' in k:
                new_k = k.replace('attn.', 'attn.w_msa.')
            elif 'mlp.' in k:
                if 'mlp.fc1.' in k:
                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
                elif 'mlp.fc2.' in k:
                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
                else:
                    new_k = k.replace('mlp.', 'ffn.')
            elif 'downsample' in k:
                new_k = k
                if 'reduction.' in k:
                    new_v = correct_unfold_reduction_order(v)
                elif 'norm.' in k:
                    new_v = correct_unfold_norm_order(v)
            else:
                new_k = k
            new_k = new_k.replace('layers', 'stages', 1)
        elif k.startswith('patch_embed'):
            new_v = v
            if 'proj' in k:
                new_k = k.replace('proj', 'projection')
            else:
                new_k = k
        else:
            new_v = v
            new_k = k

        new_ckpt[new_k] = new_v

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained swin models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src segmentation model path')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = torch.load(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_swin(state_dict)
    with open(args.dst, 'wb') as f:
        torch.save(weight, f)


if __name__ == '__main__':
    main()
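To see what `correct_unfold_reduction_order` does, here is a toy run on a single output channel; the values are chosen only to make the permutation visible:

```python
import torch

w = torch.arange(8.).reshape(1, 8)  # 1 output channel; 4 patches x 2 channels
out_channel, in_channel = w.shape
r = w.reshape(out_channel, 4, in_channel // 4)
r = r[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel)
print(w)  # tensor([[0., 1., 2., 3., 4., 5., 6., 7.]])
print(r)  # tensor([[0., 4., 2., 6., 1., 5., 3., 7.]])
```

Patch groups 1 and 2 swap places and the result is regrouped channel-first, which is the order that `nn.Unfold`-based patch merging reads its input in.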
import argparse
from collections import OrderedDict

import torch


def convert_vit(ckpt):
    new_ckpt = OrderedDict()

    for k, v in ckpt.items():
        if k.startswith('head'):
            # The classification head is not used in segmentation.
            continue
        if k.startswith('norm'):
            new_k = k.replace('norm.', 'ln1.')
        elif k.startswith('patch_embed'):
            if 'proj' in k:
                new_k = k.replace('proj', 'projection')
            else:
                new_k = k
        elif k.startswith('blocks'):
            if 'norm' in k:
                new_k = k.replace('norm', 'ln')
            elif 'mlp.fc1' in k:
                new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in k:
                new_k = k.replace('mlp.fc2', 'ffn.layers.1')
            elif 'attn.qkv' in k:
                # timm keeps q, k and v in a single linear layer.
                new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
            elif 'attn.proj' in k:
                new_k = k.replace('attn.proj', 'attn.attn.out_proj')
            else:
                new_k = k
            new_k = new_k.replace('blocks.', 'layers.')
        else:
            new_k = k
        new_ckpt[new_k] = v

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in timm pretrained vit models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src segmentation model path')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = torch.load(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        # timm checkpoint
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        # deit checkpoint
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_vit(state_dict)
    with open(args.dst, 'wb') as f:
        torch.save(weight, f)


if __name__ == '__main__':
    main()
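The `attn.qkv` to `attn.attn.in_proj_` rename works because `torch.nn.MultiheadAttention` (which the `attn.attn.` prefix in MMSegmentation's keys points at) stores its combined q/k/v projection in the same stacked layout as timm's single `qkv` linear; a quick check:

```python
import torch

mha = torch.nn.MultiheadAttention(embed_dim=4, num_heads=2)
# in_proj_weight stacks the q, k and v projections along dim 0, matching
# timm's qkv linear of shape (3 * embed_dim, embed_dim).
print(mha.in_proj_weight.shape)  # torch.Size([12, 4])
print(mha.in_proj_bias.shape)    # torch.Size([12])
```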