Skip to content
Snippets Groups Projects
Unverified Commit 7369d500 authored by MengzhangLI's avatar MengzhangLI Committed by GitHub
Browse files

[Fix] Fix SegLocalVisualizer gt_sem_seg cuda tensor error (#1845)

* [Fix] Fix SegLocalVisualizer gt_sem_seg cuda tensor error

* fix ut error and add config visualizer dict

* fix ut error
parent 5d965083
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,9 @@ env_cfg = dict(
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_level = 'INFO'
load_from = None
resume = False
......@@ -81,7 +81,7 @@ class SegLocalVisualizer(Visualizer):
"""
num_classes = len(classes)
sem_seg = sem_seg.data
sem_seg = sem_seg.cpu().data
ids = np.unique(sem_seg)[::-1]
legal_indices = ids < num_classes
ids = ids[legal_indices]
......
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from unittest import TestCase
import cv2
import mmcv
import numpy as np
import pytest
import torch
from mmengine.data import PixelData
......@@ -27,66 +29,55 @@ class TestSegLocalVisualizer(TestCase):
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
gt_sem_seg = PixelData(**gt_sem_seg_data)
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
seg_local_visualizer.dataset_meta = dict(
classes=('background', 'foreground'),
palette=[[120, 120, 120], [6, 230, 230]])
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
# test out_file
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
assert os.path.exists(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
drawn_img = cv2.imread(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
os.remove(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
os.rmdir('temp_dir' + '/vis_data/vis_image')
# test gt_instances and pred_instances
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample,
pred_seg_data_sample)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w * 2, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_pred=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
os.rmdir('temp_dir/vis_data')
os.rmdir('temp_dir')
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
def test_add_datasample_forward(gt_sem_seg):
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')],
save_dir=tmp_dir)
seg_local_visualizer.dataset_meta = dict(
classes=('background', 'foreground'),
palette=[[120, 120, 120], [6, 230, 230]])
# test out_file
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
assert os.path.exists(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'))
drawn_img = cv2.imread(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
# test gt_instances and pred_instances
pred_sem_seg_data = dict(
data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample,
pred_seg_data_sample)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w * 2, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w, 3))
def test_cityscapes_add_datasample(self):
h = 128
......@@ -110,78 +101,67 @@ class TestSegLocalVisualizer(TestCase):
gt_sem_seg_data = dict(data=sem_seg)
gt_sem_seg = PixelData(**gt_sem_seg_data)
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
seg_local_visualizer.dataset_meta = dict(
classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
'traffic light', 'traffic sign', 'vegetation', 'terrain',
'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
'motorcycle', 'bicycle'),
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
[102, 102, 156], [190, 153, 153], [153, 153, 153],
[250, 170, 30], [220, 220, 0], [107, 142, 35],
[152, 251, 152], [70, 130, 180], [220, 20, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
[0, 80, 100], [0, 0, 230], [119, 11, 32]])
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
# test out_file
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
assert os.path.exists(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
drawn_img = cv2.imread(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
os.remove(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
os.rmdir('temp_dir/vis_data/vis_image')
# test gt_instances and pred_instances
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample,
pred_seg_data_sample)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w * 2, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_pred=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
os.rmdir('temp_dir/vis_data')
os.rmdir('temp_dir')
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
def test_cityscapes_add_datasample_forward(gt_sem_seg):
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')],
save_dir='temp_dir')
seg_local_visualizer.dataset_meta = dict(
classes=('road', 'sidewalk', 'building', 'wall', 'fence',
'pole', 'traffic light', 'traffic sign',
'vegetation', 'terrain', 'sky', 'person', 'rider',
'car', 'truck', 'bus', 'train', 'motorcycle',
'bicycle'),
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
[102, 102, 156], [190, 153, 153], [153, 153, 153],
[250, 170, 30], [220, 220, 0], [107, 142, 35],
[152, 251, 152], [70, 130, 180], [220, 20, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 70],
[0, 60, 100], [0, 80, 100], [0, 0, 230],
[119, 11, 32]])
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
# test out_file
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
assert os.path.exists(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'))
drawn_img = cv2.imread(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
# test gt_instances and pred_instances
pred_sem_seg_data = dict(
data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample,
pred_seg_data_sample)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w * 2, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w, 3))
def _assert_image_and_shape(self, out_file, out_shape):
assert os.path.exists(out_file)
drawn_img = cv2.imread(out_file)
assert drawn_img.shape == out_shape
os.remove(out_file)
os.rmdir('temp_dir/vis_data/vis_image')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment