0

0

详解使用PyTorch实现目标检测与跟踪

coldplay.xixi

coldplay.xixi

发布时间:2020-12-11 17:18:45

|

9180人浏览过

|

来源于CSDN

转载

python教程栏目介绍使用PyTorch实现目标检测与跟踪

详解使用PyTorch实现目标检测与跟踪

大量免费学习推荐,敬请访问python教程(视频)

引言

在昨天的文章中,我们介绍了如何在PyTorch中使用您自己的图像来训练图像分类器,然后使用它来进行图像识别。本文将展示如何使用预训练的分类器检测图像中的多个对象,并在视频中跟踪它们。

图像中的目标检测

目标检测的算法有很多,YOLO跟SSD是现下最流行的算法。在本文中,我们将使用YOLOv3。在这里我们不会详细讨论YOLO,如果想对它有更多了解,可以参考下面的链接哦~(https://pjreddie.com/darknet/yolo/)

下面让我们开始吧,依然从导入模块开始:

from models import *
from utils import *
import os, sys, time, datetime, random
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

然后加载预训练的配置和权重,以及一些预定义的值,包括:图像的尺寸、置信度阈值和非最大抑制阈值。

config_path='config/yolov3.cfg'
weights_path='config/yolov3.weights'
class_path='config/coco.names'
img_size=416
conf_thres=0.8
nms_thres=0.4
# Load model and weights
model = Darknet(config_path, img_size=img_size)
model.load_weights(weights_path)
model.cuda()
model.eval()
classes = utils.load_classes(class_path)
Tensor = torch.cuda.FloatTensor

下面的函数将返回对指定图像的检测结果。

def detect_image(img):
    # scale and pad image
    ratio = min(img_size/img.size[0], img_size/img.size[1])
    imw = round(img.size[0] * ratio)
    imh = round(img.size[1] * ratio)
    img_transforms=transforms.Compose([transforms.Resize((imh,imw)),
         transforms.Pad((max(int((imh-imw)/2),0), 
              max(int((imw-imh)/2),0), max(int((imh-imw)/2),0),
              max(int((imw-imh)/2),0)), (128,128,128)),
         transforms.ToTensor(),
         ])
    # convert image to Tensor
    image_tensor = img_transforms(img).float()
    image_tensor = image_tensor.unsqueeze_(0)
    input_img = Variable(image_tensor.type(Tensor))
    # run inference on the model and get detections
    with torch.no_grad():
        detections = model(input_img)
        detections = utils.non_max_suppression(detections, 80, 
                        conf_thres, nms_thres)
    return detections[0]

最后,让我们通过加载一个图像,获取检测结果,然后用检测到的对象周围的包围框来显示它。并为不同的类使用不同的颜色来区分。

# load image and get detections
img_path = "images/blueangels.jpg"
prev_time = time.time()
img = Image.open(img_path)
detections = detect_image(img)
inference_time = datetime.timedelta(seconds=time.time() - prev_time)
print ('Inference Time: %s' % (inference_time))
# Get bounding-box colors
cmap = plt.get_cmap('tab20b')
colors = [cmap(i) for i in np.linspace(0, 1, 20)]
img = np.array(img)
plt.figure()
fig, ax = plt.subplots(1, figsize=(12,9))
ax.imshow(img)
pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape))
pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape))
unpad_h = img_size - pad_y
unpad_w = img_size - pad_x
if detections is not None:
    unique_labels = detections[:, -1].cpu().unique()
    n_cls_preds = len(unique_labels)
    bbox_colors = random.sample(colors, n_cls_preds)
    # browse detections and draw bounding boxes
    for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
        box_h = ((y2 - y1) / unpad_h) * img.shape[0]
        box_w = ((x2 - x1) / unpad_w) * img.shape[1]
        y1 = ((y1 - pad_y // 2) / unpad_h) * img.shape[0]
        x1 = ((x1 - pad_x // 2) / unpad_w) * img.shape[1]
        color = bbox_colors[int(np.where(
             unique_labels == int(cls_pred))[0])]
        bbox = patches.Rectangle((x1, y1), box_w, box_h,
             linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(bbox)
        plt.text(x1, y1, s=classes[int(cls_pred)], 
                color='white', verticalalignment='top',
                bbox={'color': color, 'pad': 0})
plt.axis('off')
# save image
plt.savefig(img_path.replace(".jpg", "-det.jpg"),        
                  bbox_inches='tight', pad_inches=0.0)
plt.show()

下面是我们的一些检测结果:

详解使用PyTorch实现目标检测与跟踪

详解使用PyTorch实现目标检测与跟踪

详解使用PyTorch实现目标检测与跟踪

视频中的目标跟踪

现在你知道了如何在图像中检测不同的物体。当你在一个视频中一帧一帧地看时,你会看到那些跟踪框在移动。但是如果这些视频帧中有多个对象,你如何知道一个帧中的对象是否与前一个帧中的对象相同?这被称为目标跟踪,它使用多次检测来识别一个特定的对象。

有多种算法可以做到这一点,在本文中决定使用SORT(Simple Online and Realtime Tracking),它使用Kalman滤波器预测先前识别的目标的轨迹,并将其与新的检测结果进行匹配,非常方便且速度很快。

现在开始编写代码,前3个代码段将与单幅图像检测中的代码段相同,因为它们处理的是在单帧上获得 YOLO 检测。差异在最后一部分出现,对于每个检测,我们调用 Sort 对象的 Update 函数,以获得对图像中对象的引用。因此,与前面示例中的常规检测(包括边界框的坐标和类预测)不同,我们将获得跟踪的对象,除了上面的参数,还包括一个对象 ID。并且需要使用OpenCV来读取视频并显示视频帧。

videopath = 'video/interp.mp4'
%pylab inline 
import cv2
from IPython.display import clear_output
cmap = plt.get_cmap('tab20b')
colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]
# initialize Sort object and video capture
from sort import *
vid = cv2.VideoCapture(videopath)
mot_tracker = Sort()
#while(True):
for ii in range(40):
    ret, frame = vid.read()
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pilimg = Image.fromarray(frame)
    detections = detect_image(pilimg)
    img = np.array(pilimg)
    pad_x = max(img.shape[0] - img.shape[1], 0) * 
            (img_size / max(img.shape))
    pad_y = max(img.shape[1] - img.shape[0], 0) * 
            (img_size / max(img.shape))
    unpad_h = img_size - pad_y
    unpad_w = img_size - pad_x
    if detections is not None:
        tracked_objects = mot_tracker.update(detections.cpu())
        unique_labels = detections[:, -1].cpu().unique()
        n_cls_preds = len(unique_labels)
        for x1, y1, x2, y2, obj_id, cls_pred in tracked_objects:
            box_h = int(((y2 - y1) / unpad_h) * img.shape[0])
            box_w = int(((x2 - x1) / unpad_w) * img.shape[1])
            y1 = int(((y1 - pad_y // 2) / unpad_h) * img.shape[0])
            x1 = int(((x1 - pad_x // 2) / unpad_w) * img.shape[1])
            color = colors[int(obj_id) % len(colors)]
            color = [i * 255 for i in color]
            cls = classes[int(cls_pred)]
            cv2.rectangle(frame, (x1, y1), (x1+box_w, y1+box_h),
                         color, 4)
            cv2.rectangle(frame, (x1, y1-35), (x1+len(cls)*19+60,
                         y1), color, -1)
            cv2.putText(frame, cls + "-" + str(int(obj_id)), 
                        (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 
                        1, (255,255,255), 3)
    fig=figure(figsize=(12, 8))
    title("Video Stream")
    imshow(frame)
    show()
    clear_output(wait=True)

相关免费学习推荐:php编程(视频)

ArrowMancer
ArrowMancer

手机上的宇宙动作RPG,游戏角色和元素均为AI生成

下载

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
拼多多赚钱的5种方法 拼多多赚钱的5种方法
拼多多赚钱的5种方法 拼多多赚钱的5种方法

在拼多多上赚钱主要可以通过无货源模式一件代发、精细化运营特色店铺、参与官方高流量活动、利用拼团机制社交裂变,以及成为多多进宝推广员这5种方法实现。核心策略在于通过低成本、高效率的供应链管理与营销,利用平台社交电商红利实现盈利。

28

2026.01.26

edge浏览器怎样设置主页 edge浏览器自定义设置教程
edge浏览器怎样设置主页 edge浏览器自定义设置教程

在Edge浏览器中设置主页,请依次点击右上角“...”图标 > 设置 > 开始、主页和新建标签页。在“Microsoft Edge 启动时”选择“打开以下页面”,点击“添加新页面”并输入网址。若要使用主页按钮,需在“外观”设置中开启“显示主页按钮”并设定网址。

8

2026.01.26

苹果官方查询网站 苹果手机正品激活查询入口
苹果官方查询网站 苹果手机正品激活查询入口

苹果官方查询网站主要通过 checkcoverage.apple.com/cn/zh/ 进行,可用于查询序列号(SN)对应的保修状态、激活日期及技术支持服务。此外,查找丢失设备请使用 iCloud.com/find,购买信息与物流可访问 Apple (中国大陆) 订单状态页面。

31

2026.01.26

npd人格什么意思 npd人格有什么特征
npd人格什么意思 npd人格有什么特征

NPD(Narcissistic Personality Disorder)即自恋型人格障碍,是一种心理健康问题,特点是极度夸大自我重要性、需要过度赞美与关注,同时极度缺乏共情能力,背后常掩藏着低自尊和不安全感,影响人际关系、工作和生活,通常在青少年时期开始显现,需由专业人士诊断。

3

2026.01.26

windows安全中心怎么关闭 windows安全中心怎么执行操作
windows安全中心怎么关闭 windows安全中心怎么执行操作

关闭Windows安全中心(Windows Defender)可通过系统设置暂时关闭,或使用组策略/注册表永久关闭。最简单的方法是:进入设置 > 隐私和安全性 > Windows安全中心 > 病毒和威胁防护 > 管理设置,将实时保护等选项关闭。

5

2026.01.26

2026年春运抢票攻略大全 春运抢票攻略教你三招手【技巧】
2026年春运抢票攻略大全 春运抢票攻略教你三招手【技巧】

铁路12306提供起售时间查询、起售提醒、购票预填、候补购票及误购限时免费退票五项服务,并强调官方渠道唯一性与信息安全。

35

2026.01.26

个人所得税税率表2026 个人所得税率最新税率表
个人所得税税率表2026 个人所得税率最新税率表

以工资薪金所得为例,应纳税额 = 应纳税所得额 × 税率 - 速算扣除数。应纳税所得额 = 月度收入 - 5000 元 - 专项扣除 - 专项附加扣除 - 依法确定的其他扣除。假设某员工月工资 10000 元,专项扣除 1000 元,专项附加扣除 2000 元,当月应纳税所得额为 10000 - 5000 - 1000 - 2000 = 2000 元,对应税率为 3%,速算扣除数为 0,则当月应纳税额为 2000×3% = 60 元。

12

2026.01.26

oppo云服务官网登录入口 oppo云服务登录手机版
oppo云服务官网登录入口 oppo云服务登录手机版

oppo云服务https://cloud.oppo.com/可以在云端安全存储您的照片、视频、联系人、便签等重要数据。当您的手机数据意外丢失或者需要更换手机时,可以随时将这些存储在云端的数据快速恢复到手机中。

40

2026.01.26

抖币充值官方网站 抖币性价比充值链接地址
抖币充值官方网站 抖币性价比充值链接地址

网页端充值步骤:打开浏览器,输入https://www.douyin.com,登录账号;点击右上角头像,选择“钱包”;进入“充值中心”,操作和APP端一致。注意:切勿通过第三方链接、二维码充值,谨防受骗

7

2026.01.26

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号