翼度科技»论坛 编程开发 python 查看内容

pytorch版本PSEnet训练并部署方式

4

主题

4

帖子

12

积分

新手上路

Rank: 1

积分
12
概述

源码地址
torch版本
训练环境没有按照torch的readme一样的环境,自己部署环境为:
  1. torch==1.9.1
  2. torchvision==0.10.1
  3. python==3.8.0
  4. cuda==10.2
  5. mmcv==0.2.12
  6. editdistance==0.5.3
  7. Polygon3==3.0.9.1
  8. pyclipper==1.3.0
  9. opencv-python==3.4.2.17
  10. Cython==0.29.24
复制代码
  1. ./compile.sh
复制代码
制作数据集


1、训练的数据集

采用的是rolabelimg进行标注,需要转换为ic2015格式的数据。
转换代码:
  1. import os
  2. from lxml import etree
  3. import numpy as np
  4. import math
  5. src_xml = "ANN"
  6. txt_dir = "gt"
  7. xml_listdir = os.listdir(src_xml)
  8. xml_listpath = [os.path.join(src_xml,xml_listdir1) for xml_listdir1 in xml_listdir]
  9. def xml_out(xml_path):
  10.     gt_lines = []
  11.     ET = etree.parse(xml_path)
  12.     objs = ET.findall("object")
  13.     for ix,obj in enumerate(objs):
  14.         name = obj.find("name").text
  15.         robox = obj.find("robndbox")
  16.         cx = int(float(robox.find("cx").text))
  17.         cy = int(float(robox.find("cy").text))
  18.         w = int(float(robox.find("w").text))
  19.         h = int(float(robox.find("h").text))
  20.         angle = float(robox.find("angle").text)
  21.         # angle = math.degrees(angle1)
  22.         wx1 = cx - int(0.5 * w)
  23.         wy1 = cy - int(0.5 * h)
  24.         wx2 = cx + int(0.5 * w)
  25.         wy2 = cy - int(0.5 * h)
  26.         wx3 = cx - int(0.5 * w)
  27.         wy3 = cy + int(0.5 * h)
  28.         wx4 = cx + int(0.5 * w)
  29.         wy4 = cy + int(0.5 * h)
  30.         x1 = int((wx1 - cx) * np.cos(angle) - (wy1 - cy) * np.sin(angle) + cx)
  31.         y1 = int((wx1 - cx) * np.sin(angle) - (wy1 - cy) * np.cos(angle) + cy)
  32.         x2 = int((wx2 - cx) * np.cos(angle) - (wy2 - cy) * np.sin(angle) + cx)
  33.         y2 = int((wx2 - cx) * np.sin(angle) - (wy2 - cy) * np.cos(angle) + cy)
  34.         x3 = int((wx3 - cx) * np.cos(angle) - (wy3 - cy) * np.sin(angle) + cx)
  35.         y3 = int((wx3 - cx) * np.sin(angle) - (wy3 - cy) * np.cos(angle) + cy)
  36.         x4 = int((wx4 - cx) * np.cos(angle) - (wy4 - cy) * np.sin(angle) + cx)
  37.         y4 = int((wx4 - cx) * np.sin(angle) - (wy4 - cy) * np.cos(angle) + cy)
  38.         lines = str(x1)+","+str(y1)+","+str(x2)+","+str(y2)+","+\
  39.                 str(x3)+","+str(y3)+","+str(x4)+","+str(y4)+","+str(name)+"\n"
  40.         gt_lines.append(lines)
  41.         return gt_lines
  42. def main():
  43.     count = 0
  44.     for xml_dir in xml_listdir:
  45.         gt_lines = xml_out(os.path.join(src_xml,xml_dir))
  46.         txt_path = "gt_" + xml_dir[:-4] + ".txt"
  47.         with open(os.path.join(txt_dir,txt_path),"a+") as fd:
  48.             fd.writelines(gt_lines)
  49.         count +=1
  50.         print("Write file %s" % str(count))
  51. if __name__ == "__main__":
  52.     main()
复制代码
rolabelimg标注后的xml文件和labelimg的xml有些区别,根据不同的标注软件,转换代码略有区别。
转换后的格式为
  1. x1,y1,x2,y2,x3,y3,x4,y4,"classes"
复制代码
,此处classes为检测的类别,如果是模糊训练的话,classes为“###”。
但是重点,这个源代码对于模糊训练,loss一直为1。

2、将数据集分成训练集和测试集


这里可以按照源码路径存放数据集,也可以修改源码存放位置。
PSENet-python3\dataset\psenet\psenet_ic15.py
修改下述代码为自己文件夹


3、训练
  1. CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py
复制代码
其中根据源码中的readme,

可以根据自己的需要,自行选择配置文件。


4、部署测试
  1. import torch
  2. import numpy as np
  3. import argparse
  4. import os
  5. import os.path as osp
  6. import sys
  7. import time
  8. import json
  9. from mmcv import Config
  10. import cv2
  11. from torchvision import transforms
  12. from dataset import build_data_loader
  13. from models import build_model
  14. from models.utils import fuse_module
  15. from utils import ResultFormat, AverageMeter
  16. def prepare_image(image, target_size):
  17.     """Do image preprocessing before prediction on any data.
  18.     :param image:       original image
  19.     :param target_size: target image size
  20.     :return:
  21.                         preprocessed image
  22.     """
  23.     #assert os.path.exists(img), 'file is not exists'
  24.     #img = cv2.imread(img)
  25.     img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  26.     # h, w = image.shape[:2]
  27.     # scale = long_size / max(h, w)
  28.     img = cv2.resize(img, target_size)
  29.     # 将图片由(w,h)变为(1,img_channel,h,w)
  30.     tensor = transforms.ToTensor()(img)
  31.     tensor = tensor.unsqueeze_(0)
  32.     tensor = tensor.to(torch.device("cuda:0"))
  33.     return tensor
  34. def report_speed(outputs, speed_meters):
  35.     total_time = 0
  36.     for key in outputs:
  37.         if 'time' in key:
  38.             total_time += outputs[key]
  39.             speed_meters[key].update(outputs[key])
  40.             print('%s: %.4f' % (key, speed_meters[key].avg))
  41.     speed_meters['total_time'].update(total_time)
  42.     print('FPS: %.1f' % (1.0 / speed_meters['total_time'].avg))
  43. def load_model(cfg):
  44.     model = build_model(cfg.model)
  45.     model = model.cuda()
  46.     model.eval()
  47.     checkpoint = "psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"
  48.     if checkpoint is not None:
  49.         if os.path.isfile(checkpoint):
  50.             print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
  51.             sys.stdout.flush()
  52.             checkpoint = torch.load(checkpoint)
  53.             d = dict()
  54.             for key, value in checkpoint['state_dict'].items():
  55.                 tmp = key[7:]
  56.                 d[tmp] = value
  57.             model.load_state_dict(d)
  58.         else:
  59.             print("No checkpoint found at")
  60.             raise
  61.         # fuse conv and bn
  62.     model = fuse_module(model)
  63.     return model
  64. if __name__ == '__main__':
  65.     src_dir = "testimg/"
  66.     save_dir = "test_save/"
  67.     if not os.path.exists(save_dir):
  68.         os.makedirs(save_dir)
  69.     cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
  70.     for d in [cfg, cfg.data.test]:
  71.         d.update(dict(
  72.             report_speed=False
  73.         ))
  74.     if cfg.report_speed:
  75.         speed_meters = dict(
  76.             backbone_time=AverageMeter(500),
  77.             neck_time=AverageMeter(500),
  78.             det_head_time=AverageMeter(500),
  79.             det_pse_time=AverageMeter(500),
  80.             rec_time=AverageMeter(500),
  81.             total_time=AverageMeter(500)
  82.         )
  83.     model = load_model(cfg)
  84.     model.eval()
  85.     count = 0
  86.     for img_name in os.listdir(src_dir):
  87.         img = cv2.imread(src_dir + img_name)
  88.         tensor = prepare_image(img, target_size=(1376, 1024))
  89.         data = dict()
  90.         img_metas = dict()
  91.         data['imgs'] = tensor
  92.         img_metas['org_img_size'] = torch.tensor([[img.shape[0], img.shape[1]]])
  93.         img_metas['img_size'] = torch.tensor([[1376, 1024]])
  94.         data['img_metas'] = img_metas
  95.         data.update(dict(
  96.             cfg=cfg
  97.         ))
  98.         with torch.no_grad():
  99.             outputs = model(**data)
  100.         if cfg.report_speed:
  101.             report_speed(outputs, speed_meters)
  102.         for bboxes in outputs['bboxes']:
  103.             x1 = bboxes[0]
  104.             y1 = bboxes[1]
  105.             x2 = bboxes[4]
  106.             y2 = bboxes[5]
  107.             cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)
  108.         count = count + 1
  109.         cv2.imwrite(save_dir + img_name, img)
  110.         print("img test:", count)
复制代码
  1. from dataset import build_data_loader
  2. from models import build_model
  3. from models.utils import fuse_module
  4. from utils import ResultFormat, AverageMeter
复制代码
训练代码里含有。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

来源:https://www.jb51.net/article/283806.htm
免责声明:由于采集信息均来自互联网,如果侵犯了您的权益,请联系我们【E-Mail:cb@itdo.tech】 我们会及时删除侵权内容,谢谢合作!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x

举报 回复 使用道具