|
概述
源码地址
torch版本
训练环境没有完全按照源码 README 中的要求搭建,本人实际部署的环境如下:
- torch==1.9.1
- torchvision==0.10.1
- python==3.8.0
- cuda==10.2
- mmcv==0.2.12
- editdistance==0.5.3
- Polygon3==3.0.9.1
- pyclipper==1.3.0
- opencv-python==3.4.2.17
- Cython==0.29.24
制作数据集
1、训练的数据集
数据使用 roLabelImg 进行标注(旋转框),标注结果需要先转换为 ICDAR 2015(IC15)格式的数据。
转换代码:- import os
- from lxml import etree
- import numpy as np
- import math
# Folder holding the roLabelImg XML annotations, and the folder that receives
# one "gt_<name>.txt" ground-truth file per annotation.
src_xml = "ANN"
txt_dir = "gt"

# Every entry of the annotation folder, and the same entries joined onto it.
xml_listdir = os.listdir(src_xml)
xml_listpath = list(map(lambda entry: os.path.join(src_xml, entry), xml_listdir))
def _rotated_corners(cx, cy, w, h, angle):
    """Return the four corners of a rotated box as integer (x, y) pairs.

    The box is centred on (cx, cy), has width ``w`` and height ``h`` and is
    rotated by ``angle`` (radians, the unit roLabelImg stores). Corners are
    returned in the order top-left, top-right, bottom-right, bottom-left of the
    unrotated box, i.e. clockwise on screen, which is the order IC15-style
    ground truth expects. Coordinates are rounded only at the very end.
    """
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    half_w = 0.5 * w
    half_h = 0.5 * h
    # Offsets of the unrotated corners from the centre (image y grows downwards).
    offsets = (
        (-half_w, -half_h),  # top-left
        (half_w, -half_h),   # top-right
        (half_w, half_h),    # bottom-right
        (-half_w, half_h),   # bottom-left
    )
    corners = []
    for dx, dy in offsets:
        # Standard 2-D rotation about the centre.
        x = dx * cos_a - dy * sin_a + cx
        y = dx * sin_a + dy * cos_a + cy
        corners.append((int(round(x)), int(round(y))))
    return corners


def xml_out(xml_path):
    """Convert one roLabelImg XML file into IC15-style ground-truth lines.

    Every ``<object>`` becomes one line ``"x1,y1,x2,y2,x3,y3,x4,y4,name\\n"``
    with the corners in clockwise order. Rotated boxes are read from
    ``<robndbox>`` (cx, cy, w, h, angle); objects saved as a plain
    ``<bndbox>`` (xmin, ymin, xmax, ymax) are treated as angle 0.

    :param xml_path: path to a roLabelImg annotation file
    :return: list of text lines, one per annotated object
    """
    gt_lines = []
    tree = etree.parse(xml_path)
    for obj in tree.findall("object"):
        name = obj.find("name").text
        robox = obj.find("robndbox")
        if robox is not None:
            cx = float(robox.find("cx").text)
            cy = float(robox.find("cy").text)
            w = float(robox.find("w").text)
            h = float(robox.find("h").text)
            # assumed to be radians (roLabelImg's unit) -- no degree conversion.
            angle = float(robox.find("angle").text)
        else:
            # Unrotated boxes are stored as an ordinary axis-aligned <bndbox>.
            box = obj.find("bndbox")
            xmin = float(box.find("xmin").text)
            ymin = float(box.find("ymin").text)
            xmax = float(box.find("xmax").text)
            ymax = float(box.find("ymax").text)
            cx, cy = (xmin + xmax) / 2.0, (ymin + ymax) / 2.0
            w, h = xmax - xmin, ymax - ymin
            angle = 0.0
        corners = _rotated_corners(cx, cy, w, h, angle)
        coords = ",".join("%d,%d" % (x, y) for x, y in corners)
        gt_lines.append(coords + "," + str(name) + "\n")
    return gt_lines
def main():
    """Convert every roLabelImg XML file in ``src_xml`` into ``txt_dir``.

    Each ``<name>.xml`` produces ``gt_<name>.txt``. Output files are
    overwritten, so re-running the script never duplicates ground-truth lines;
    entries that are not ``.xml`` files are skipped.
    """
    os.makedirs(txt_dir, exist_ok=True)
    count = 0
    for xml_dir in xml_listdir:
        stem, ext = os.path.splitext(xml_dir)
        if ext.lower() != ".xml":
            # Not an annotation file; slicing its name would produce junk.
            continue
        gt_lines = xml_out(os.path.join(src_xml, xml_dir))
        txt_path = "gt_" + stem + ".txt"
        # Overwrite ("w") so a rerun replaces the file instead of appending to it.
        with open(os.path.join(txt_dir, txt_path), "w") as fd:
            fd.writelines(gt_lines)
        count += 1
        print("Write file %s" % str(count))


if __name__ == "__main__":
    main()
roLabelImg 标注得到的 xml 文件与 labelImg 的 xml 有些区别(旋转框保存在 robndbox 节点中,包含 cx、cy、w、h、angle 五个字段),因此使用不同的标注软件时,转换代码也需要相应调整。
转换后每个目标占一行,格式为:
- x1,y1,x2,y2,x3,y3,x4,y4,classes
(IC15 格式要求四个顶点按顺时针顺序给出,即左上、右上、右下、左下。)
其中 classes 为检测的类别;对于模糊、无法辨认的文本区域,classes 写为“###”。
需要特别注意:用这份源码做模糊训练(即把目标标注为“###”)时,loss 会一直为 1。原因应该是源码把“###”区域当作忽略区域,不参与 loss 计算,所以全部标为“###”时模型学不到任何东西。
2、将数据集分成训练集和测试集
这里可以按照源码路径存放数据集,也可以修改源码存放位置。
PSENet-python3\dataset\psenet\psenet_ic15.py
将该文件开头定义的数据集路径(如 ic15_root_dir 以及训练/测试图片和标注所在的目录)修改为自己的文件夹路径。
3、训练
- CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py
源码的 README 中提供了多种配置文件(对应不同的 backbone、输入尺寸等),可以根据自己的需要自行选择。
4、部署测试
- import torch
- import numpy as np
- import argparse
- import os
- import os.path as osp
- import sys
- import time
- import json
- from mmcv import Config
- import cv2
- from torchvision import transforms
- from dataset import build_data_loader
- from models import build_model
- from models.utils import fuse_module
- from utils import ResultFormat, AverageMeter
def prepare_image(image, target_size, device="cuda:0"):
    """Preprocess a BGR image (as returned by ``cv2.imread``) for the model.

    Steps: BGR -> RGB, resize, ``ToTensor`` (HWC uint8 -> CHW float in
    [0, 1]), ImageNet mean/std normalisation, add a batch dimension.
    NOTE(review): the normalisation constants are assumed to match the PSENet
    dataset loaders (dataset/psenet/psenet_ic15.py) -- confirm if they change.

    :param image: original image, an HxWx3 BGR array
    :param target_size: ``(width, height)``, passed straight to ``cv2.resize``
    :param device: device the returned tensor is moved to (default "cuda:0")
    :return: float tensor of shape (1, 3, height, width) on ``device``
    :raises ValueError: if ``image`` is None (e.g. ``cv2.imread`` failed)
    """
    if image is None:
        raise ValueError("image is None (file missing or not an image?)")
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # cv2.resize expects dsize as (width, height).
    img = cv2.resize(img, target_size)
    tensor = transforms.ToTensor()(img)
    # Feed the model the same input distribution it was trained on.
    tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])(tensor)
    tensor = tensor.unsqueeze(0)
    return tensor.to(torch.device(device))
def report_speed(outputs, speed_meters):
    """Feed per-stage timings from ``outputs`` into meters and print averages.

    Every key of ``outputs`` whose name contains ``'time'`` is pushed into the
    meter with the same name in ``speed_meters`` and its running average is
    printed. The sum of those values is pushed into
    ``speed_meters['total_time']``, and ``1 / average`` is printed as FPS.
    """
    timing_keys = [name for name in outputs if 'time' in name]
    elapsed = 0
    for name in timing_keys:
        value = outputs[name]
        elapsed += value
        meter = speed_meters[name]
        meter.update(value)
        print('%s: %.4f' % (name, meter.avg))
    total_meter = speed_meters['total_time']
    total_meter.update(elapsed)
    print('FPS: %.1f' % (1.0 / total_meter.avg))
def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys.

    Checkpoints saved from an ``nn.DataParallel``-wrapped model carry that
    prefix; keys without it are kept unchanged instead of losing 7 characters.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(cfg, checkpoint="psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"):
    """Build the model from ``cfg``, load trained weights and fuse conv+bn.

    :param cfg: config whose ``model`` entry is passed to ``build_model``
    :param checkpoint: path to a checkpoint containing a ``'state_dict'``
        entry; ``None`` skips weight loading
    :return: the model on the GPU, in eval mode, after ``fuse_module``
    :raises FileNotFoundError: if ``checkpoint`` is given but is not a file
    """
    model = build_model(cfg.model)
    model = model.cuda()
    model.eval()
    if checkpoint is not None:
        if not os.path.isfile(checkpoint):
            raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint))
        print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
        sys.stdout.flush()
        # Load to CPU first; load_state_dict copies the tensors into the
        # model's (GPU) parameters, so no extra GPU copy is kept around.
        state = torch.load(checkpoint, map_location="cpu")
        model.load_state_dict(_strip_module_prefix(state['state_dict']))
    # fuse conv and bn
    model = fuse_module(model)
    return model
if __name__ == '__main__':
    src_dir = "testimg/"
    save_dir = "test_save/"
    os.makedirs(save_dir, exist_ok=True)

    cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
    for d in [cfg, cfg.data.test]:
        d.update(dict(
            report_speed=False
        ))
    if cfg.report_speed:
        speed_meters = dict(
            backbone_time=AverageMeter(500),
            neck_time=AverageMeter(500),
            det_head_time=AverageMeter(500),
            det_pse_time=AverageMeter(500),
            rec_time=AverageMeter(500),
            total_time=AverageMeter(500)
        )

    model = load_model(cfg)
    model.eval()

    # (width, height), as handed to cv2.resize inside prepare_image.
    input_size = (1376, 1024)
    count = 0
    for img_name in os.listdir(src_dir):
        img = cv2.imread(os.path.join(src_dir, img_name))
        if img is None:
            # Not an image (or unreadable): skip it instead of crashing the run.
            print("skip unreadable file:", img_name)
            continue
        tensor = prepare_image(img, target_size=input_size)
        data = dict()
        img_metas = dict()
        data['imgs'] = tensor
        # Both meta sizes are (height, width), like img.shape[:2]. The resized
        # size is read back from the (N, C, H, W) tensor so it cannot disagree
        # with what the model actually sees.
        img_metas['org_img_size'] = torch.tensor([[img.shape[0], img.shape[1]]])
        img_metas['img_size'] = torch.tensor([list(tensor.shape[2:])])
        data['img_metas'] = img_metas
        data.update(dict(
            cfg=cfg
        ))
        with torch.no_grad():
            outputs = model(**data)
        if cfg.report_speed:
            report_speed(outputs, speed_meters)
        for bbox in outputs['bboxes']:
            # Each bbox appears to be a flat x1,y1,x2,y2,... sequence (TODO
            # confirm against the PSENet head). Drawing it as a closed polygon
            # keeps rotated boxes rotated instead of forcing them axis-aligned.
            pts = np.asarray(bbox, dtype=np.int32).reshape(-1, 1, 2)
            cv2.polylines(img, [pts], True, (0, 0, 255), 3)
        count = count + 1
        cv2.imwrite(os.path.join(save_dir, img_name), img)
        print("img test:", count)
复制代码- from dataset import build_data_loader
- from models import build_model
- from models.utils import fuse_module
- from utils import ResultFormat, AverageMeter
以上几个模块(dataset、models、utils)都包含在 PSENet 训练源码中,因此测试脚本需要放在源码根目录下运行,才能正确导入。
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
来源:https://www.jb51.net/article/283806.htm
免责声明:由于采集信息均来自互联网,如果侵犯了您的权益,请联系我们【E-Mail:cb@itdo.tech】 我们会及时删除侵权内容,谢谢合作! |
本帖子中包含更多资源
您需要 登录 才可以下载或查看,没有账号?立即注册
x
|