From e385842557766763f2eba660498ff0a26d0f4848 Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com> Date: Thu, 24 Sep 2020 19:34:40 +0200 Subject: [PATCH] Fix cpu inference (#152) * Add missing map_location * Add docstring * Update mmseg/apis/inference.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> * Update inference.py * Update inference.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> --- mmseg/apis/inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 3ba6b62c..6fa7e3b3 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -16,7 +16,8 @@ def init_segmentor(config, checkpoint=None, device='cuda:0'): object. checkpoint (str, optional): Checkpoint path. If left as None, the model will not load any weights. - + device (str, optional) CPU/CUDA device option. Default 'cuda:0'. + Use 'cpu' for loading model on CPU. Returns: nn.Module: The constructed segmentor. """ @@ -28,7 +29,7 @@ def init_segmentor(config, checkpoint=None, device='cuda:0'): config.model.pretrained = None model = build_segmentor(config.model, test_cfg=config.test_cfg) if checkpoint is not None: - checkpoint = load_checkpoint(model, checkpoint) + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') model.CLASSES = checkpoint['meta']['CLASSES'] model.PALETTE = checkpoint['meta']['PALETTE'] model.cfg = config # save the config in the model for convenience -- GitLab