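"""MMDetection 可视化 / MMDetection result visualization.

左边为原图，右边为检测结果图。
(Left: original image; right: detection result.)

如果有 GT（COCO 格式的 json 标注文件），可在原图上同时画出 GT 标注。
(If a ground-truth COCO json is available, its annotations are also drawn on
the original image.)
"""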

import os
import cv2
import mmcv
import torch
import numpy as np
from pycocotools.coco import COCO, maskUtils
from mmdet.apis import init_detector, inference_detector, show_result

from mmdet.apis import init_dist, inference
from mmdet.core import results2json, coco_eval, wrap_fp16_model
from mmdet.datasets import build_dataloader, build_dataset

import colorsys
import random
def get_n_hls_colors(num):
    """Return ``num`` HLS colours whose hues split the colour wheel evenly.

    The hue of colour ``k`` is ``k / num``. Saturation is drawn uniformly from
    [0.9, 1.0) and lightness from [0.5, 0.6) using the global ``random``
    state, so repeated calls give the same hues with slightly different shades.

    Args:
        num: Number of colours to generate. Values below 1 yield ``[]``.

    Returns:
        A list of ``[h, l, s]`` lists (each component in [0, 1]), in the
        argument order expected by ``colorsys.hls_to_rgb``.
    """
    if num < 1:
        # Without this guard, num == 0 divides by zero and a negative num
        # makes the hue step negative, so the hue never reaches 360.
        return []
    hls_colors = []
    for k in range(num):
        # Derive the hue from the index instead of accumulating a float step:
        # repeated ``i += 360 / num`` can stop just below 360 and yield
        # num + 1 colours.
        hue = k * 360.0 / num
        sat = 90 + random.random() * 10
        light = 50 + random.random() * 10
        hls_colors.append([hue / 360.0, light / 100.0, sat / 100.0])
    return hls_colors

def ncolors(num):
    """Return ``num`` distinct colours as ``[r, g, b]`` lists of ints in [0, 255].

    Colours come from ``get_n_hls_colors`` (evenly spaced hues). Each is
    converted to RGB and each channel is scaled by 255 and truncated to int.

    Args:
        num: Number of colours wanted. Values below 1 yield ``[]``.

    Returns:
        A list of ``num`` ``[r, g, b]`` lists, in RGB order (reverse them for
        OpenCV's BGR).
    """
    if num < 1:
        return []
    rgb_colors = []
    for hue, light, sat in get_n_hls_colors(num):
        rgb = colorsys.hls_to_rgb(hue, light, sat)
        rgb_colors.append([int(channel * 255.0) for channel in rgb])
    return rgb_colors

def mkdir_os(path):
    """Create directory ``path`` (with any missing parents) unless it exists.

    If anything already exists at ``path`` — a directory or any other
    file-system entry — the call is a no-op.
    """
    if os.path.exists(path):
        return
    os.makedirs(path)

# ---------------------------------------------------------------------------
# Run settings
# ---------------------------------------------------------------------------
# Score threshold applied to detections before they are drawn.
score_thr = 0.3
# Annotation file name. NOTE(review): not referenced by the visible code;
# the dataset branch gets its annotations from the config's test dataset.
test_json = 'instances_test2017.json'
# 1 -> run on every file in a plain image folder;
# anything else -> iterate the config's test dataset and also draw its GT.
loadflag = 0

# Model files (paths relative to the current working directory).
config_file = 'configs/nut5_fine_faster_rcnn_r50_fpn_1x.py'
checkpoint_file = 'checkpoints/epoch_55.pth'

# Build the detector on the first GPU and put it in evaluation mode.
model = init_detector(config_file, checkpoint_file, device='cuda:0')
model.eval()
def _draw_predictions(img, result, score_thr):
    """Draw predicted masks and boxes (with scores) onto ``img`` in place.

    ``result`` is what ``inference_detector`` returned: either the per-class
    bbox arrays, or a ``(bbox_result, segm_result)`` tuple for mask models.
    All masks share one colour (fixed seed, so the same colour every image);
    boxes and score text are green. Returns ``img`` for convenience.
    """
    # Split bbox and segm results.
    if isinstance(result, tuple):
        bbox_result, segm_result = result
    else:
        bbox_result, segm_result = result, None
    bboxes = np.vstack(bbox_result)

    # Originally written for a brake-pad Mask R-CNN: masks are drawn in a
    # single colour rather than one colour per class.
    if segm_result is not None:
        segms = mmcv.concat_list(segm_result)
        inds = np.where(bboxes[:, -1] > score_thr)[0]
        np.random.seed(42)
        color_masks = np.random.randint(0, 256, (1, 3), dtype=np.uint8)
        for i in inds:
            # np.bool was removed in NumPy 1.24; the builtin is equivalent.
            mask = maskUtils.decode(segms[int(i)]).astype(bool)
            img[mask] = img[mask] * 0.5 + color_masks * 0.5

    if score_thr > 0:
        assert bboxes.shape[1] == 5
        bboxes = bboxes[bboxes[:, -1] > score_thr, :]

    font_scale = 0.8
    thickness = 4
    bbox_color = (0, 255, 0)
    text_color = (0, 255, 0)
    for bbox in bboxes:
        bbox_int = bbox.astype(np.int32)
        left_top = (int(bbox_int[0]), int(bbox_int[1]))
        right_bottom = (int(bbox_int[2]), int(bbox_int[3]))
        cv2.rectangle(img, left_top, right_bottom, bbox_color,
                      thickness=thickness)
        # Only boxes carrying a score column get a label; the original drew
        # label_text unconditionally and could hit an unassigned name.
        if len(bbox) > 4:
            label_text = '{:.02f}'.format(bbox[-1])
            cv2.putText(img, label_text, (left_top[0], left_top[1] - 5),
                        cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
    return img


def _draw_gt(img, coco, img_info, palette):
    """Draw the GT annotations of one COCO image onto ``img`` in place.

    Polygons with more than two vertices are drawn as dots at each vertex.
    Two-point "polygons" are drawn as a rectangle with the category name; the
    dataset apparently stores boxes this way (TODO confirm). Crowd or ignored
    annotations use the last palette colour; the rest are coloured by
    category id. ``palette`` holds RGB triplets, reversed here to BGR for
    OpenCV.
    """
    ann_ids = coco.getAnnIds(imgIds=img_info['id'], catIds=[], iscrowd=None)
    anns = coco.loadAnns(ann_ids)
    # Build the id -> name map once instead of scanning per polygon.
    cat_names = {cat['id']: cat['name'] for cat in coco.dataset['categories']}

    point_size = 2
    thickness = 2
    for ann in anns:
        if 'segmentation' not in ann:
            continue
        segmentation = ann['segmentation']
        if not isinstance(segmentation, list):
            # RLE masks are not drawn. Skip them instead of aborting the
            # whole run (the original called exit() here).
            print('skip non-polygon GT segmentation, ann id:', ann.get('id'))
            continue

        # Standard COCO annotations have no "ignore" key; treat missing as 0.
        if ann['iscrowd'] == 0 and ann.get('ignore', 0) == 0:
            # Modulo keeps ids >= len(palette) from raising IndexError;
            # ids inside the palette range pick the same colour as before.
            rgb = palette[ann['category_id'] % len(palette)]
        else:
            rgb = palette[-1]
        color = tuple(rgb[::-1])  # RGB -> BGR
        label = cat_names.get(ann['category_id'], 'error')

        for seg in segmentation:
            points = np.array(seg).reshape((int(len(seg) / 2), 2))
            if len(points) > 2:
                # Polygon: OpenCV needs integer coordinates.
                for x, y in points:
                    cv2.circle(img, (int(x), int(y)), point_size, color,
                               thickness)
            elif len(points) == 2:
                # Box stored as two corner points.
                top_left = (int(points[0][0]), int(points[0][1]))
                bottom_right = (int(points[1][0]), int(points[1][1]))
                cv2.rectangle(img, top_left, bottom_right, color, thickness)
                cv2.putText(img, label, top_left, cv2.FONT_HERSHEY_SIMPLEX,
                            0.8, color, 2)
    return img


def _side_by_side(left, right):
    """Return a new uint8 image with ``left`` and ``right`` placed side by side.

    Assumes both inputs are 3-channel; a shorter image is padded with black.
    """
    h1, w1 = left.shape[:2]
    h2, w2 = right.shape[:2]
    vis = np.zeros((max(h1, h2), w1 + w2, 3), np.uint8)
    vis[:h1, :w1, :] = left
    vis[:h2, w1:w1 + w2, :] = right
    return vis


if loadflag == 1:
    # Mode 1: run the detector on every file in a plain folder.
    save_path = "./result"
    mkdir_os(save_path)
    path = "/home/test_infe"
    imgs = [os.path.join(path, name) for name in os.listdir(path)]

    # Iterate the test images.
    count = len(imgs)
    for lab, img_file in enumerate(imgs):
        print(lab, '/', count)
        result = inference_detector(model, img_file)

        # Left: untouched original; right: predictions drawn on a copy.
        img = mmcv.imread(img_file)
        oriimg = img.copy()
        _draw_predictions(img, result, score_thr)

        out_file = os.path.join(save_path, 'result_{}.jpg'.format(lab))
        cv2.imwrite(out_file, _side_by_side(oriimg, img))
else:
    # Mode 0: iterate the config's test dataset.
    # Left: original image with GT; right: predictions.
    save_path = "./result"
    mkdir_os(save_path)
    path = "/home/train2017"
    cnum = 8
    self_color = ncolors(cnum)

    # Save a legend strip: one 30-px band per palette colour (RGB -> BGR).
    colorbar_vis = np.zeros((cnum * 30, 100, 3), dtype=np.uint8)
    for ind, colo in enumerate(self_color):
        colorbar_vis[ind * 30:(ind + 1) * 30, 0:100] = colo[::-1]
    cv2.imwrite('./colorbar_vis.png', colorbar_vis)

    # NOTE(review): eval_types/outputs are unused in the visible code; they
    # appear to be kept for an optional results2json/coco_eval step.
    eval_types = ['bbox', 'segm']
    cfg = mmcv.Config.fromfile(config_file)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True
    dataset = build_dataset(cfg.data.test)
    coco = dataset.coco

    outputs = []
    count = len(dataset)
    for idx in range(count):
        print(idx, '/', count)
        img_id = dataset.img_ids[idx]
        img_info = coco.loadImgs(img_id)[0]
        img_path = os.path.join(path, img_info['file_name'])
        result = inference_detector(model, img_path)
        outputs.append(result)

        img = mmcv.imread(img_path)
        oriimg = img.copy()
        _draw_predictions(img, result, score_thr)
        _draw_gt(oriimg, coco, img_info, self_color)

        out_file = os.path.join(save_path, 'result_{}.jpg'.format(idx))
        cv2.imwrite(out_file, _side_by_side(oriimg, img))

 

# --- Residue from the web page this script was copied from (not code) ---
# 最新回复(0)
# /jishu9mvwZeYlJb47XrqWMMbZOl_2BEst4ndbnUdqYASw_3D_3D4795321
# 8 简首页