Refactoring yolov5-6.0 detect.py for object detection

Code implementation

Getting the model parameters

As usual, we start by importing the required packages.

import torch
from config import *
from models.experimental import attempt_load
from utils.general import check_img_size
from utils.torch_utils import select_device
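
The wildcard import pulls constants from a local config.py, which is not shown in this post. A minimal sketch of what it might contain (the constant names are the ones used below; every value here is an assumption, so adjust the weights path and thresholds to your own setup):

# config.py (hypothetical example)
WEIGHTS = 'yolov5s.pt'     # path to the .pt (or TorchScript) weights file
IMGSZ = (640, 640)         # inference size (height, width)
CONF_THRES = 0.25          # confidence threshold for NMS
IOU_THRES = 0.45           # IoU threshold for NMS
LINE_THICKNESS = 3         # bounding-box line thickness in pixels
HIDE_LABELS = False        # if True, draw boxes without class names
HIDE_CONF = False          # if True, draw class names without confidence scores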

Next we load the parameters of the .pt model. The following function returns the corresponding model, device, half, stride, names and imgsz values.

def get_model():
    device = select_device('')  # '' lets YOLOv5 pick CUDA if available, otherwise CPU
    # print(device.type)

    half = device.type != 'cpu'  # half precision (FP16) is only supported on CUDA

    # load TorchScript weights or a regular .pt checkpoint
    model = torch.jit.load(WEIGHTS) if 'torchscript' in WEIGHTS else attempt_load(WEIGHTS, map_location=device)
    stride = int(model.stride.max())  # model stride
    names = model.module.names if hasattr(model, 'module') else model.names  # get class names
    if half:
        model.half()  # convert model weights to FP16
    imgsz = check_img_size(IMGSZ, s=stride)  # verify the inference size is a multiple of the stride

    return model, device, half, stride, names, imgsz
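
A quick sanity check that the weights load correctly (hypothetical usage of the function above):

model, device, half, stride, names, imgsz = get_model()
print(device, stride, imgsz)  # e.g. cuda:0, 32, [640, 640]
print(names)                  # class names the model was trained on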

Running detection on the target image

Import the packages

import numpy as np
import torch
from config import CONF_THRES, IOU_THRES, LINE_THICKNESS, HIDE_LABELS, HIDE_CONF
from utils.augmentations import letterbox
from utils.general import non_max_suppression, scale_coords, xyxy2xywh
from utils.plots import Annotator, colors

This function takes img0, the image to be detected, plus the .pt model parameters model, device, half, stride, names, imgsz, and returns the annotated image together with the box coordinates (see the comment above the return statement for the format).

@torch.no_grad()
def out_img(img0, model, device, half, stride, names, imgsz):
    # Padded resize
    img = letterbox(img0, imgsz, stride=stride, auto=True)[0]
    # Convert
    img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    img = np.ascontiguousarray(img)
    model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters())))  # run once to warm up
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32
    img = img / 255.0  # normalize: 0 - 255 to 0.0 - 1.0

    if len(img.shape) == 3:
        img = img[None]  # expand for batch dim

    pred = model(img, augment=False, visualize=False)[0]
    # NMS
    pred = non_max_suppression(pred, CONF_THRES, IOU_THRES, None, False, max_det=1000)
    # Process predictions
    det = pred[0]

    im0 = img0.copy()
    gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh

    annotator = Annotator(im0, line_width=LINE_THICKNESS, example=str(names))
    xywh_list = []
    if len(det):
        # Rescale boxes from img_size to im0 size
        det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
        # Write results
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)  # integer class
            xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
            xywh_list.append(xywh)
            label = None if HIDE_LABELS else (names[c] if HIDE_CONF else f'{names[c]} {conf:.2f}')
            annotator.box_label(xyxy, label, color=colors(c, True))
    im0 = annotator.result()
    # im0 is the annotated image; xywh_list holds one normalized [x, y, w, h] entry per box:
    # x and y are the box centre as fractions of the image width/height (origin at the top-left corner),
    # w and h are the box width and height as fractions of the image size
    return im0, xywh_list
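
A minimal single-image run using the two functions above (a sketch; test.jpg and result.jpg are hypothetical paths):

import cv2

if __name__ == '__main__':
    model, device, half, stride, names, imgsz = get_model()
    img0 = cv2.imread('test.jpg')  # BGR ndarray, as letterbox and Annotator expect
    im0, xywh_list = out_img(img0, model, device, half, stride, names, imgsz)
    print(xywh_list)               # normalized [x_center, y_center, w, h] per detection
    cv2.imwrite('result.jpg', im0) # save the annotated image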

Calling from main

You can display the result with cv2.imshow(window_name, img) inside a while loop.

You can also save it to a specified local path, e.g. with cv2.imwrite(img_outpath, img), since the returned image is an OpenCV ndarray.

Or call the function as a building block for other features, for example a live webcam loop as sketched below.
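
A minimal sketch of such a loop, assuming a webcam at index 0 and the window name 'yolov5' (both hypothetical choices):

import cv2

if __name__ == '__main__':
    model, device, half, stride, names, imgsz = get_model()
    cap = cv2.VideoCapture(0)  # first local camera
    while cap.isOpened():
        ok, frame = cap.read()
        if not ok:
            break
        im0, xywh_list = out_img(frame, model, device, half, stride, names, imgsz)
        cv2.imshow('yolov5', im0)                  # display the annotated frame
        if cv2.waitKey(1) & 0xFF == ord('q'):      # press q to quit
            break
    cap.release()
    cv2.destroyAllWindows()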