← 返回首页
Yolov8+DeepSort实现车流量统计
发表时间:2024-07-17 16:15:06
Yolov8+DeepSort实现车流量统计

在ByteTrack被提出之前,可以说DeepSORT是最好的目标跟踪算法之一。本文,我们就来应用这个算法实现车流量统计。

1.DeepSORT简介

DeepSORT,全称为Deep SORT,是在SORT(Simple Online and Realtime Tracking)基础上引入深度学习特征的改进版本,由Alex Bewley等人提出。它结合了深度学习的强大表征能力与传统方法的高效性,旨在实现更精确、更稳定的多目标追踪。DeepSORT不仅继承了SORT的实时性优势,还通过深度特征的融入显著提高了追踪的鲁棒性和重识别能力,特别是在目标外观变化较大或存在短暂遮挡的情况下。

DeepSORT的官方网址如下:

https://github.com/nwojke/deep_sort

但在这里,我们不使用官方的代码,而使用第三方代码,其Github网址为:

https://github.com/levan92/deep_sort_realtime

1). 安装DeepSORT

pip install deep-sort-realtime

可以看出,DeepSORT算法只是需要几个常规的软件包:numpy、scipy和opencv-python,对用户十分友好。

使用DeepSORT也很方便,先导入DeepSORT:

from deep_sort_realtime.deepsort_tracker import DeepSort

2).DeepSORT 核心参数

实例化:

tracker = DeepSort()

DeepSort有一些输入参数,在这里只介绍几个常用的参数:

目标跟踪:

tracks = tracker.update_tracks(bbs, frame=frame)

bbs为目标检测器的结果列表,每个结果是一个元组,形式为([left,top,w,h],置信值,类型),其中类型为字符串型。frame为帧图像,输出tracks为目标跟踪结果,使用for循环可以得到各个目标的跟踪信息:

for track in tracks:

下面介绍一些track的常用属性和方法:

2.代码实现

下面我们就给出DeepSORT实现车流量统计的完整程序,在这里,我们仍然使用YOLOv8作为目标检测器:

import numpy as np
import cv2
from ultralytics import YOLO
from deep_sort_realtime.deepsort_tracker import DeepSort

model = YOLO('yolov8l.pt')

cap = cv2.VideoCapture("./video/street_car_001.mp4")
fps = cap.get(cv2.CAP_PROP_FPS)
size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
fNUMS = cap.get(cv2.CAP_PROP_FRAME_COUNT)
fourcc = cv2.VideoWriter_fourcc(*'H264')
videoWriter = cv2.VideoWriter("./traffic_result.mp4", fourcc, fps, size)

tracker = DeepSort(max_age=5)

counter_set=set()

def box_label(image, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
    p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
    cv2.rectangle(image, p1, p2, color, thickness=1, lineType=cv2.LINE_AA)
    if label:
        w, h = cv2.getTextSize(label, 0, fontScale=2 / 3, thickness=1)[0]
        outside = p1[1] - h >= 3
        p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
        cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA)
        cv2.putText(image,
                    label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                    0, 2 / 3, txt_color, thickness=1, lineType=cv2.LINE_AA)


while cap.isOpened():
    success, frame = cap.read()

    if success:
        results = model(frame, conf=0.4)
        outputs = results[0].boxes.data.cpu().numpy()

        detections = []

        if outputs is not None:
            for output in outputs:
                x1, y1, x2, y2 = list(map(int, output[:4]))
                if output[5] == 2:
                    detections.append(([x1, y1, int(x2 - x1), int(y2 - y1)], output[4], 'car'))
                elif output[5] == 5:
                    detections.append(([x1, y1, int(x2 - x1), int(y2 - y1)], output[4], 'bus'))
                elif output[5] == 7:
                    detections.append(([x1, y1, int(x2 - x1), int(y2 - y1)], output[4], 'truck'))

            tracks = tracker.update_tracks(detections, frame=frame)

            for track in tracks:
                if not track.is_confirmed():
                    continue
                track_id = track.track_id
                counter_set.add(track_id)
                bbox = track.to_ltrb()

                box_label(frame, bbox, '#' + str(int(track_id)) + track.det_class, (167, 146, 11))
        counter = len(counter_set)
        cv2.putText(frame, f"vehicle count:{counter}", (25, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

        cv2.imshow("traffic detected", frame)
        videoWriter.write(frame)

        '''
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
        '''
        # waitKey() 函数的功能是不断刷新图像 , 频率时间为delay , 单位为ms
        cv2.waitKey(1)
        if cv2.getWindowProperty('traffic detected', cv2.WND_PROP_VISIBLE) < 1:
            break
    else:
        break

cap.release()
videoWriter.release()
cv2.destroyAllWindows()

实现效果: