Commit 2eaf7ee8 authored by 谢昕辰, committed by MeowZheng
[Doc] Update inference doc (#2052)


* refactor inference doc

* introduce segdatasample

* add vis example

* fix

* Update docs/en/user_guides/3_inference.md

* update link

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
parent b8d87d7f
@@ -2,24 +2,142 @@
MMSegmentation provides pre-trained models for semantic segmentation in [Model Zoo](../model_zoo.md), and supports multiple standard datasets, including Cityscapes, ADE20K, etc.
This note will show how to use existing models to run inference on given images.
As for how to test existing models on standard datasets, please see this [guide](./4_train_test.md).

## Inference API

MMSegmentation provides several interfaces for users to easily use pre-trained models for inference.
- [mmseg.apis.init_model](#mmsegapisinit_model)
- [mmseg.apis.inference_model](#mmsegapisinference_model)
- [mmseg.apis.show_result_pyplot](#mmsegapisshow_result_pyplot)
### mmseg.apis.init_model
Initialize a segmentor from a config file.
Parameters:
- config (str, `Path`, or `mmengine.Config`) - Config file path or the config object.
- checkpoint (str, optional) - Checkpoint path. If left as None, the model will not load any weights.
- device (str, optional) - CPU/CUDA device option. Default 'cuda:0'.
- cfg_options (dict, optional) - Options to override some settings in the used config.
Returns:
- nn.Module: The constructed segmentor.
Example:
```python
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
# register all modules in mmseg into the registries
register_all_modules()
# initialize model without checkpoint
model = init_model(config_path)
# init model and load checkpoint
model = init_model(config_path, checkpoint_path)
# init model and load checkpoint on CPU
model = init_model(config_path, checkpoint_path, 'cpu')
```
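The `cfg_options` argument can be used to override settings of the loaded config on the fly. Below is a minimal sketch under the assumption that the config follows the usual MMSegmentation layout with a `model.test_cfg` section; the dotted keys and values are illustrative only.

```python
from mmseg.apis import init_model
from mmseg.utils import register_all_modules

config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'

# register all modules in mmseg into the registries
register_all_modules()
# override config entries with dotted keys, e.g. switch to sliding-window inference
# (illustrative values; adjust the keys to match your own config)
model = init_model(
    config_path,
    checkpoint_path,
    device='cuda:0',
    cfg_options={
        'model.test_cfg.mode': 'slide',
        'model.test_cfg.crop_size': (512, 1024),
        'model.test_cfg.stride': (341, 341),
    })
```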
### mmseg.apis.inference_model
Run inference on image(s) with the segmentor.
Parameters:
- model (nn.Module) - The loaded segmentor.
- imgs (str, np.ndarray, or list\[str/np.ndarray\]) - Either image files or loaded images
Returns:
- `SegDataSample` or list\[`SegDataSample`\]: If `imgs` is a list or tuple, a list of results with the same length is returned; otherwise the segmentation result is returned directly.
**Note:** [SegDataSample](https://github.com/open-mmlab/mmsegmentation/blob/1.x/mmseg/structures/seg_data_sample.py) is a data structure of MMSegmentation that serves as the interface between different components. `SegDataSample` implements the abstract data element `mmengine.structures.BaseDataElement`; please refer to the data element [documentation](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_element.html) in [MMEngine](https://github.com/open-mmlab/mmengine) for more information.
The attributes in `SegDataSample` are divided into several parts:
- `gt_sem_seg` (`PixelData`) - Ground truth of semantic segmentation.
- `pred_sem_seg` (`PixelData`) - Prediction of semantic segmentation.
- `seg_logits` (`PixelData`) - Predicted logits of semantic segmentation.
**Note:** [PixelData](https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/pixel_data.py) is the data structure for pixel-level annotations or predictions; please refer to the PixelData [documentation](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_element.html) in [MMEngine](https://github.com/open-mmlab/mmengine) for more information.
Example:
```python
from mmseg.apis import init_model, inference_model
from mmseg.utils import register_all_modules

config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
img_path = 'demo/demo.png'

# register all modules in mmseg into the registries
register_all_modules()
model = init_model(config_path, checkpoint_path)
result = inference_model(model, img_path)
```
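Building on the notes about `SegDataSample` and `PixelData` above, here is a minimal sketch of inspecting the returned result. It continues from the example above; the tensor shapes in the comments are assumptions for the typical single-image case.

```python
import torch

# `result` is the SegDataSample returned by inference_model above.
# Predicted label map stored as a PixelData field (typically shape (1, H, W)).
pred = result.pred_sem_seg.data
# Predicted logits, one channel per class (typically shape (num_classes, H, W)).
logits = result.seg_logits.data

print(pred.shape, logits.shape)
# class indices that actually appear in the prediction
print(torch.unique(pred))
```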
### mmseg.apis.show_result_pyplot
Visualize the segmentation results on the image.
Parameters:
- model (nn.Module) - The loaded segmentor.
- img (str or np.ndarray) - Image filename or loaded image.
- result (`SegDataSample`) - The prediction SegDataSample result.
- opacity (float) - Opacity of the painted segmentation map. Defaults to `0.5`; must be in the `(0, 1]` range.
- title (str) - The title of the pyplot figure. Defaults to `''`.
- draw_gt (bool) - Whether to draw GT SegDataSample. Defaults to `True`.
- draw_pred (bool) - Whether to draw Prediction SegDataSample. Defaults to `True`.
- wait_time (float) - The interval of show in seconds; `0` is the special value that means "forever". Defaults to `0`.
- show (bool) - Whether to display the drawn image. Defaults to `True`.
- save_dir (str, optional) - Save file dir for all storage backends. If it is `None`, the backend storage will not save any data.
- out_file (str, optional) - Path to output file. Default to `None`.
Returns:
- np.ndarray: The drawn image in RGB channel order.
Example:
```python
from mmseg.apis import init_model, inference_model, show_result_pyplot
from mmseg.utils import register_all_modules
config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
img_path = 'demo/demo.png'
# register all modules in mmseg into the registries
register_all_modules()
# build the model from a config file and a checkpoint file
model = init_model(config_path, checkpoint_path, device='cuda:0')

# inference on given image
result = inference_model(model, img_path)

# display the segmentation result
vis_image = show_result_pyplot(model, img_path, result)
# save the visualization result; the output image will be saved to `work_dirs/result.png`
vis_image = show_result_pyplot(model, img_path, result, out_file='work_dirs/result.png')
# modify the display time of the image; note that 0 is the special value that means "forever"
vis_image = show_result_pyplot(model, img_path, result, wait_time=5)
```
**Note:** If your current device does not have a graphical user interface, it is recommended to set `show` to `False` and specify `out_file` or `save_dir` to save the results. If you would like to display the result in a window, no special settings are required.
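For example, on a headless server the call might look like the sketch below, which continues from the example above; the output path is an arbitrary assumption.

```python
# save the visualization without opening a window (suitable for headless servers)
vis_image = show_result_pyplot(
    model,
    img_path,
    result,
    show=False,
    out_file='work_dirs/headless_result.png')
```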
@@ -169,12 +169,16 @@ def show_result_pyplot(model: BaseSegmentor,
        draw_gt (bool): Whether to draw GT SegDataSample. Default to True.
        draw_pred (bool): Whether to draw Prediction SegDataSample.
            Defaults to True.
        wait_time (float): The interval of show (s). 0 is the special value
            that means "forever". Defaults to 0.
        show (bool): Whether to display the drawn image.
            Default to True.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        out_file (str, optional): Path to output file. Default to None.

    Returns:
        np.ndarray: the drawn image which channel is RGB.
    """
    if hasattr(model, 'module'):
        model = model.module
    ...