COCO 数据集：根据文件名中的 ID 划分 train 和 test 集

1. COCO 数据集：根据文件名中的 ID 划分 train 和 test 集
# -*- coding: utf-8 -*-
import os
import cv2
import json
import argparse
import numpy as np
import random
from pycocotools.coco import COCO, maskUtils

def _append_image_to_split(coco, img_info, cat_ids, split):
    """Copy one image and its annotations into *split*, renumbering their ids.

    The copied image gets ``id = len(split['images'])``; its annotations get
    consecutive ids continuing from ``len(split['annotations'])`` and point at
    the new image id, so ids are 0-based and dense within each split.

    Args:
        coco: loaded ``COCO`` object the image comes from.
        img_info: image record from ``coco.loadImgs`` (copied, not mutated).
        cat_ids: category filter passed through to ``coco.getAnnIds``.
        split: dict with ``'images'`` and ``'annotations'`` lists, appended to.

    Raises:
        ValueError: if an annotation has no ``segmentation`` field.
    """
    image_id = len(split['images'])
    ann_ids = coco.getAnnIds(imgIds=img_info['id'], catIds=cat_ids, iscrowd=None)
    for ann in coco.loadAnns(ann_ids):
        # Every annotation is expected to carry a segmentation; stop loudly
        # (non-zero exit) on the first one that does not.
        if 'segmentation' not in ann:
            raise ValueError(
                "error no segmentation (annotation id {})".format(ann.get('id')))
        split['annotations'].append(
            dict(ann, id=len(split['annotations']), image_id=image_id))
    split['images'].append(dict(img_info, id=image_id))


def main(args):
    """Split a COCO annotation file into a test json and a train json.

    Each image's file name is split on ``_``; the third field is an id key of
    the form ``"<L>-<R>"`` (L in [0, 100), R in [0, 1000)). After a seeded
    shuffle, the first two images seen for each key go to the test split and
    the rest go to the train split. Image and annotation ids are renumbered
    from 0 inside each split.

    Args:
        args: namespace with ``input_json`` (source COCO json), ``output_json``
            (test output path) and optionally ``output_train_json`` (train
            output path; defaults to
            ``./annotations/instances_train2017_select-test.json``).

    Raises:
        KeyError: if a file name's id key is outside the expected range.
        ValueError: if an annotation has no ``segmentation`` field.
    """
    # Valid id keys "<L>-<R>" -> number of images with that key put into test.
    test_counts = dict.fromkeys(
        (str(m_key) + "-" + str(n_key)
         for m_key in range(0, 100) for n_key in range(0, 1000)),
        0)

    coco = COCO(args.input_json)
    cat_ids = []  # empty list: no category filter
    img_ids = coco.getImgIds(catIds=cat_ids)

    # Shuffle with a fixed seed so the test selection is random but reproducible.
    random.Random(10086).shuffle(img_ids)

    test_split = {'images': [], 'annotations': []}
    train_split = {'images': [], 'annotations': []}

    for i, img_id in enumerate(img_ids):
        if i % 100 == 0:
            print(i, "/", len(img_ids))

        img_info = coco.loadImgs(img_id)[0]
        file_name = img_info['file_name']
        fields = file_name.split('_')

        # NOTE(review): images whose name starts with "data-augmentation" are
        # written to neither split. If augmented images are meant for
        # training, append them to train_split here instead — confirm intent.
        if fields[0] == 'data-augmentation':
            continue

        key = fields[2]
        if key not in test_counts:
            raise KeyError("id key {!r} of {!r} is outside the expected "
                           "'<0-99>-<0-999>' range".format(key, file_name))

        # At most two images per id key go to the test split.
        if test_counts[key] < 2:
            test_counts[key] += 1
            _append_image_to_split(coco, img_info, cat_ids, test_split)
        else:
            _append_image_to_split(coco, img_info, cat_ids, train_split)

    print("test image num:", len(test_split['images']))
    print("train image num:", len(train_split['images']))

    train_json = (getattr(args, 'output_train_json', None)
                  or './annotations/instances_train2017_select-test.json')

    import io
    for path, split in ((args.output_json, test_split),
                        (train_json, train_split)):
        instance = {
            'license': ['license'],
            'info': 'spytensor created',
            'categories': coco.dataset['categories'],
            'images': split['images'],
            'annotations': split['annotations'],
        }
        with io.open(path, 'w', encoding="utf-8") as outfile:
            outfile.write(json.dumps(instance, ensure_ascii=False, indent=1))



if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Split a COCO annotation file into test/train subsets "
                    "by the id key embedded in each image file name.")
    parser.add_argument('-ij',
                        "--input_json",
                        default='./annotations/instances_train2017.json',
                        help="input COCO annotation json")
    parser.add_argument('-oj',
                        "--output_json",
                        default='./deadline-20201125_instances_test2017_ZD-data.json',
                        help="output json for the selected test split")
    parser.add_argument('-otj',
                        "--output_train_json",
                        default='./annotations/instances_train2017_select-test.json',
                        help="output json for the remaining train split")
    args = parser.parse_args()

    if args.output_json is None:
        parser.print_help()
        exit()

    main(args)
2. COCO 数据集：将标注 JSON 文件按类别拆分为多个单类别 JSON 文件
# -*- coding: utf-8 -*-
import os
import cv2
import argparse
import numpy as np
from pycocotools.coco import COCO, maskUtils
import json

# Categories iterated by main(), one output json per entry; ids are 1-based in
# listing order. NOTE(review): presumably these mirror the "categories" section
# of the input annotation file — verify against the dataset.
all_categories = [
    {"name": category_name, "id": category_id}
    for category_id, category_name in enumerate(
        ("nut", "wires_x", "dustproof", "nizi", "gate_nut"), start=1)
]

def main(argv):
    """Write one COCO json file per entry of ``all_categories``.

    For each category, keeps only the images that contain it and only that
    category's annotations, renumbering image and annotation ids from 0 within
    each output file. Every output keeps the input's full ``categories`` list so
    category ids are unchanged. Categories whose name is not found in the input
    are skipped with a message.

    Args:
        argv: parsed argparse namespace with ``json_file`` (input COCO json)
            and ``output_file`` (output directory, created if missing).

    Raises:
        ValueError: if an annotation has no ``segmentation`` field.
    """
    import io

    coco = COCO(argv.json_file)

    out_dir = argv.output_file
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    for category in all_categories:
        # Look the category up by *name*: getCatIds' catNms filter matches
        # category names, so passing the numeric id matched nothing.
        cat_ids = coco.getCatIds(catNms=[category['name']])
        if not cat_ids:
            # An empty catIds list means "no filter" to pycocotools and would
            # dump every image/annotation into this category's file.
            print("skip category not found in {}: {}".format(
                argv.json_file, category['name']))
            continue

        images = []
        annotations = []
        for new_img_id, img_id in enumerate(coco.getImgIds(catIds=cat_ids)):
            img_info = coco.loadImgs(img_id)[0]
            ann_ids = coco.getAnnIds(imgIds=img_info['id'], catIds=cat_ids,
                                     iscrowd=None)
            for ann in coco.loadAnns(ann_ids):
                # Every annotation is expected to carry a segmentation.
                if 'segmentation' not in ann:
                    raise ValueError("error no segmentation (annotation id {})"
                                     .format(ann.get('id')))
                annotations.append(
                    dict(ann, id=len(annotations), image_id=new_img_id))
            images.append(dict(img_info, id=new_img_id))

        instance = {
            'license': ['license'],
            'info': 'spytensor created',
            'categories': coco.dataset['categories'],
            'images': images,
            'annotations': annotations,
        }
        out_path = os.path.join(out_dir,
                                "{}_test2017.json".format(category['name']))
        with io.open(out_path, 'w', encoding="utf-8") as outfile:
            outfile.write(json.dumps(instance, ensure_ascii=False, indent=1))


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Split a COCO annotation file into one json per category "
                    "listed in all_categories.")
    parser.add_argument('-if',
                        "--input_file",
                        default='./train2017/',
                        help="image folder (currently unused)")
    parser.add_argument('-oj',
                        "--json_file",
                        default='./annotations/batch5-ZD-data_instances_test2017.json',
                        help="input COCO annotation json")
    parser.add_argument('-of',
                        "--output_file",
                        default='./resule_json/',
                        help="output folder for the per-category json files")
    args = parser.parse_args()

    if args.output_file is None:
        parser.print_help()
        exit()

    main(args)

 

最新回复(0)
/jishubpbD2AjUYG_2BK3ckHZ04dRYstnp0c2kk6iMKJASTS338_3D4795319
8 简首页