# -*- coding: utf-8 -*-
"""
author: 赵杭天
date: 20.1
base: official baseline
"""

from reader import data_loader
from glob import glob
from sklearn.cluster import KMeans

# Annotation XML files of the train and val splits.
# glob() gives no ordering guarantee, so the lists are sorted: the fallback
# image ids (taken from a running counter) and the seeded k-means clustering
# further down both depend on the order in which the files are visited.
xmls_train = sorted(glob('./insects/train/annotations/xmls/*.xml'))
xmls_val = sorted(glob('./insects/val/annotations/xmls/*.xml'))

# print(len(xmls_train), len(xmls_val), len(xmls))
import xml.etree.ElementTree as ET
import os
import numpy as np

# Running count of processed annotation files; also used as the image id
# for files that carry no <id> tag.
ct = 0
# records = [[{...}, {...}, ...], [{...}, {...}, ...]]:
# the first list holds the train dicts, the second list holds the val dicts.
records = [[] for _ in range(2)]

# Category names of the dataset; a name's position in this list is its
# integer class label.
INSECT_NAMES = [
    'Boerner',
    'Leconte',
    'Linnaeus',
    'acuminatus',
    'armandi',
    'coleoptera',
    'linnaeus',
]

def get_insect_names():
    """Build the mapping from insect category name to integer label.

    Returns:
        dict: a new dict on every call, keyed by the entries of
        INSECT_NAMES and valued by their index in that list:

            {'Boerner': 0,
             'Leconte': 1,
             'Linnaeus': 2,
             'acuminatus': 3,
             'armandi': 4,
             'coleoptera': 5,
             'linnaeus': 6}
    """
    return {name: label for label, name in enumerate(INSECT_NAMES)}

# Parse every annotation XML of both splits into a record dict and collect the
# records per split (records[0] = train, records[1] = val).
# The name -> label mapping does not change between files, so build it once.
cname2cid = get_insect_names()
for split_section, xmls in enumerate((xmls_train, xmls_val)):
    for fpath in xmls:
        tree = ET.parse(fpath)

        # Prefer the <id> tag of the file; fall back to the running counter.
        id_node = tree.find('id')
        im_id = np.array([ct if id_node is None else int(id_node.text)])

        size_node = tree.find('size')
        im_w = float(size_node.find('width').text)
        im_h = float(size_node.find('height').text)

        objs = tree.findall('object')
        num_objs = len(objs)
        gt_bbox = np.zeros((num_objs, 4), dtype=np.float32)
        gt_class = np.zeros((num_objs,), dtype=np.int32)
        # No crowd flag is read from the XML, so this stays all zeros.
        is_crowd = np.zeros((num_objs,), dtype=np.int32)
        difficult = np.zeros((num_objs,), dtype=np.int32)

        for row, obj in enumerate(objs):
            gt_class[row] = cname2cid[obj.find('name').text]
            difficult[row] = int(obj.find('difficult').text)

            box = obj.find('bndbox')
            # Clamp the corner coordinates to the image area.
            x1 = max(0, float(box.find('xmin').text))
            y1 = max(0, float(box.find('ymin').text))
            x2 = min(im_w - 1, float(box.find('xmax').text))
            y2 = min(im_h - 1, float(box.find('ymax').text))
            # Stored as (center x, center y, width, height).
            gt_bbox[row] = [
                (x1 + x2) / 2.0,
                (y1 + y2) / 2.0,
                x2 - x1 + 1.,
                y2 - y1 + 1.,
            ]

        voc_rec = {
            'im_id': im_id,
            'h': im_h,
            'w': im_w,
            'is_crowd': is_crowd,
            'gt_class': gt_class,
            'gt_bbox': gt_bbox,
            'gt_poly': [],
            'difficult': difficult
        }
        # Files without any annotated object are not recorded, but they still
        # advance the counter.
        if objs:
            records[split_section].append(voc_rec)
        ct += 1

# Collect the normalized (width, height) of every ground-truth box, per split
# and per class, and print the per-class box counts.
#   wh_total[split]            -> list of [w / image_w, h / image_h] arrays
#   num_total[split][class_id] -> number of boxes of that class
# (split 0 = train, split 1 = val)
wh_total = [[], []]
num_total = [[], []]
for split_section, section_records in enumerate(records):
    print('%s gt bbox数量:' % ('train' if split_section == 0 else 'val'))
    # wh means width and height
    wh_cls_wise = [[] for _ in INSECT_NAMES]
    for r in section_records:
        # Divide into a NEW array. Slicing gt_bbox yields a view, so
        # normalizing the slice in place would overwrite the pixel
        # width/height stored in the record (leaving pixel centres next to
        # normalized sizes, and normalizing twice if this section is re-run).
        img_wh = np.array([r['w'], r['h']], dtype=np.float32)
        wh_norm = r['gt_bbox'][:, 2:] / img_wh
        for bb, cls_id in zip(wh_norm, r['gt_class']):
            wh_total[split_section].append(bb)
            wh_cls_wise[cls_id].append(bb)

    for i, c in enumerate(wh_cls_wise):
        print(INSECT_NAMES[i], ':\t', len(c))
        num_total[split_section].append(len(c))

    print('Total gt bbox: %d \n' % len(wh_total[split_section]))

# Plotting: grouped bar chart of the per-class box counts, train next to val.
import matplotlib.pyplot as plt

name_list = INSECT_NAMES
num_list_1, num_list_2 = num_total[0], num_total[1]
x = list(range(len(num_list_1)))
# Each class group is total_width wide and shared by n bars.
total_width, n = 0.8, 2
width = total_width / n

plt.figure(figsize=(10, 8))
plt.bar(x, num_list_1, width=width, label='train', fc='g')
# Shift the positions by one bar width so the val bars sit right of the
# train bars; the class names are attached to these shifted positions.
x = [pos + width for pos in x]
plt.bar(x, num_list_2, width=width, label='val', tick_label=name_list, fc='y')
plt.legend()
# plt.show()

# Cluster the normalized (w, h) of all train + val boxes into anchor_num
# anchors, scale them to pixels and print them sorted by area.
anchor_num = 9
# Edge length in pixels that the normalized anchors are scaled to.
# NOTE(review): presumably the network input resolution; confirm.
input_size = 608
# n_init is pinned explicitly: scikit-learn 1.4 changed its default from 10 to
# 'auto' (a single run with the default k-means++ init), so leaving it unset
# yields different anchors depending on the installed version.
kmeans = KMeans(n_clusters=anchor_num, n_init=10, random_state=0).fit(wh_total[0] + wh_total[1])
cluster_results = list(input_size * kmeans.cluster_centers_)
# Smallest anchor first, ordered by area (w * h).
cluster_results.sort(key=lambda wh: wh[1] * wh[0])
print('anchor_num = %d, 聚类结果:%s' % (anchor_num, cluster_results))

import matplotlib.patches as patches

# Draw every clustered anchor as a red outline, all centred on the middle of
# a blank 100x100 canvas, then show all figures.
plt.figure(2)
canvas = np.full((100, 100, 3), 255)
plt.imshow(canvas)
currentAxis = plt.gca()
for w, h in cluster_results:
    # Rectangle takes its lower-left corner, so offset the centre (50, 50)
    # by half the anchor size.
    corner = (50 - w / 2, 50 - h / 2)
    currentAxis.add_patch(
        patches.Rectangle(corner, w, h, linewidth=1, edgecolor='r', facecolor='none')
    )
plt.show()

# 各类别train/val统计柱状图

# bbox聚类结果

# 打印统计结果

# 最后修改日期: 2020年3月1日

# 作者

# 留言

# 撰写回覆或留言

# 发布留言必须填写的电子邮件地址不会公开。