深度学习系列53:mmdetection上手

1. 安装

使用openmim安装:

pip install -U openmim
mim install "mmengine>=0.7.0"
mim install "mmcv>=2.0.0rc4"

2. 测试案例

下载代码和模型:

git clone https://github.com/open-mmlab/mmdetection.git
mkdir ./checkpoints
mim download mmdet --config rtmdet_tiny_8xb32-300e_coco --dest ./checkpoints

运行代码,核心是定义inferencer和使用inferencer进行推理两行:

from mmdet.apis import DetInferencer

# Choose to use a config
model_name = 'rtmdet_tiny_8xb32-300e_coco'
# Setup a checkpoint file to load
checkpoint = './checkpoints/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'

# Set the device to be used for evaluation
device = 'cpu'

# Initialize the DetInferencer
inferencer = DetInferencer(model_name, checkpoint, device)

# Use the detector to do inference
img = 'demo.jpg'
result = inferencer(img, out_dir='./output')

# Show the structure of result dict
from rich.pretty import pprint
pprint(result, max_length=4)

# Show the output image
from PIL import Image
Image.open('./output/vis/demo.jpg')

3. 自定义数据进行训练

3.1 准备数据

建议使用coco格式,参见https://cocodataset.org/#format-data。文件从头至尾按照顺序分为以下段落:

{
“info”: info,
“licenses”: [license],
“images”: [image],
“annotations”: [annotation],
“categories”: [category]
}
下面是从instances_val2017.json文件中摘出的一个annotation的实例,这里的segmentation就是polygon格式:

{
“segmentation”: [[510.66,423.01,511.72,420.03,510.45…]],
“area”: 702.1057499999998,
“iscrowd”: 0,
“image_id”: 289343,
“bbox”: [473.07,395.93,38.65,28.67],
“category_id”: 18,
“id”: 1768
},
从instances_val2017.json文件中摘出的2个category实例如下所示:

{
“supercategory”: “person”,
“id”: 1,
“name”: “person”
},
{
“supercategory”: “vehicle”,
“id”: 2,
“name”: “bicycle”
},

我们来看测试案例的例子,包含三个大字段,其中categories非常简单,只有一个balloon(我们需要训练的目标)
在这里插入图片描述
images则是如下的清单:
在这里插入图片描述
annotations如下:
在这里插入图片描述

3.2 配置config文件

config文件中需要定义数据,模型,训练参数,优化器等各种参数。测试案例如下:

config_balloon = """
# Inherit and overwrite part of the config based on this config
_base_ = './rtmdet_tiny_8xb32-300e_coco.py'

data_root = 'data/balloon/' # dataset root

train_batch_size_per_gpu = 4
train_num_workers = 2

max_epochs = 20
stage2_num_epochs = 1
base_lr = 0.00008


metainfo = {
    'classes': ('balloon', ),
    'palette': [
        (220, 20, 60),
    ]
}

train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        data_prefix=dict(img='train/'),
        ann_file='train.json'))

val_dataloader = dict(
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        data_prefix=dict(img='val/'),
        ann_file='val.json'))

test_dataloader = val_dataloader

val_evaluator = dict(ann_file=data_root + 'val.json')

test_evaluator = val_evaluator

model = dict(bbox_head=dict(num_classes=1))

# learning rate
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1.0e-5,
        by_epoch=False,
        begin=0,
        end=10),
    dict(
        # use cosine lr from 10 to 20 epoch
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]

train_pipeline_stage2 = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='RandomResize',
        scale=(640, 640),
        ratio_range=(0.1, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=(640, 640)),
    dict(type='YOLOXHSVRandomAug'),
    dict(type='RandomFlip', prob=0.5),
    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
    dict(type='PackDetInputs')
]

# optimizer
optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

default_hooks = dict(
    checkpoint=dict(
        interval=5,
        max_keep_ckpts=2,  # only keep latest 2 checkpoints
        save_best='auto'
    ),
    logger=dict(type='LoggerHook', interval=5))

custom_hooks = [
    dict(
        type='PipelineSwitchHook',
        switch_epoch=max_epochs - stage2_num_epochs,
        switch_pipeline=train_pipeline_stage2)
]

# load COCO pre-trained weight
load_from = './checkpoints/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
visualizer = dict(vis_backends=[dict(type='LocalVisBackend'),dict(type='TensorboardVisBackend')])
"""

with open('../configs/rtmdet/rtmdet_tiny_1xb4-20e_balloon.py', 'w') as f:
    f.write(config_balloon)

3.3 开始训练

使用Mac M2芯片需要修改3个地方。首先是需要设置

export PYTORCH_ENABLE_MPS_FALLBACK=1

其次是mmcv中的nms需要转到cpu上计算,打开mmcv/ops/nms.py,将class NMSop(torch.autograd.Function)中的inds = ext_module.nms(bboxes, scores…)改为inds = ext_module.nms(bboxes.cpu(), scores.cpu()…)
运行后会出现一个assert报错,找到源代码,把那一行assert删掉即可。
运行完成后,可以查看tensorboard:

%load_ext tensorboard

# see curves in tensorboard
%tensorboard --logdir ./work_dirs

然后查看测试结果

from mmdet.apis import DetInferencer
import glob

# Choose to use a config
config = '../configs/rtmdet/rtmdet_tiny_1xb4-20e_balloon.py'
# Setup a checkpoint file to load
checkpoint = glob.glob('./work_dirs/rtmdet_tiny_1xb4-20e_balloon/best_coco*.pth')[0]

# Set the device to be used for evaluation
device = 'cpu'

# Initialize the DetInferencer
inferencer = DetInferencer(config, checkpoint, device)

# Use the detector to do inference
img = './data/balloon/val/4838031651_3e7b5ea5c7_b.jpg'
result = inferencer(img, out_dir='./output')
# Show the output image
Image.open('./output/vis/4838031651_3e7b5ea5c7_b.jpg')

在这里插入图片描述

4. 其他

MMYOLO:传统的目标检测库
MMRotate:旋转检测库
MMDetection3D:三维检测库
下面几期一一介绍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/163528.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2023年【熔化焊接与热切割】考试试卷及熔化焊接与热切割试题及解析

题库来源:安全生产模拟考试一点通公众号小程序 熔化焊接与热切割考试试卷考前必练!安全生产模拟考试一点通每个月更新熔化焊接与热切割试题及解析题目及答案!多做几遍,其实通过熔化焊接与热切割模拟考试很简单。 1、【单选题】 对…

React+后端实现导出Excle表格的功能

最近在做一个基于Reactantd前端框架的Excel导出功能,我主要在后端做了处理,这个功能完成后,便总结成一篇技术分享文章,感兴趣的小伙伴可以参考该分享来做导出excle表格功能,以下步骤同样适用于vue框架,或者…

“轻松实现文件夹批量重命名:使用顺序编号批量改名“

你是否曾经遇到过需要批量重命名文件夹,却因为繁琐的手动操作而感到困扰?现在,我们为你带来了一款全新的工具——轻松实现文件夹批量重命名,使用顺序编号批量改名。这款工具将帮助你轻松解决文件夹重命名的问题,提高工…

SpringSecurity5|12.实现RememberMe 及 实现原理分析

security/day08 这个功能大家还熟悉么?我们在登录网站的时候,除了让你输入用户名和密码,还会有个勾选框: 记住我!!!不是让大家记住我哈。 值得一提的是,Spring Security 也提供了这个…

2023年汉字小达人市级比赛在线模拟题更新:40分钟150题完整对标

今天是2023年11月19日,距离11月30日的汉字小达人市级比赛还有11天。许多孩子正在利用难得的周末抓紧练习和备赛。 结合一些孩子的反馈和需求,我把150题的在线模拟题做了更新,增加了前面的个人信息填写的部分,并且把整个试卷的完成…

python自动化标注工具+自定义目标P图替换+深度学习大模型(代码+教程+告别手动标注)

省流建议 本文针对以下需求: 想自动化标注一些目标不再想使用yolo想在目标检测/语意分割有所建树计算机视觉项目想玩一玩大模型了解自动化工具了解最前沿模型自定义目标P图替换… 确定好需求,那么我们发车! 实现功能与结果 该模型将首先…

python——第九天

今日目标: 偏函数 递归 字符串对象 切片 常见排序和查找 偏函数: python中存在一种函数的特殊使用,称为偏函数 如果在调用某个函数时,恰好某一个或者,某一些参数都是一个固定值(正好不是默认值)…

Linux常用命令——bye命令

在线Linux命令查询工具 bye 命令用于中断FTP连线并结束程序。。 补充说明 bye命令在ftp模式下,输入bye即可中断目前的连线作业,并结束ftp的执行。 语法 bye实例 bye在线Linux命令查询工具

从0开始学习JavaScript--JavaScript 数字与日期

JavaScript中的数字和日期是处理数值计算和时间相关任务的核心。本文将深入研究JavaScript中数字的表示、常见运算,以及日期对象的创建、格式化等操作,并通过丰富的示例代码,可以更全面地了解和应用这些概念。 JavaScript数字基础 JavaScri…

【RocketMq系列-01】RocketMq安装和基本概念

RocketMq系列整体栏目 内容链接地址【一】RocketMq安装和基本概念https://zhenghuisheng.blog.csdn.net/article/details/134486709 RocketMq安装和基本概念 一,RocketMq安装和基本概念1,RocketMq基本安装(本地安装)2,Rocketmq的核心概念2.1&…

Linux使用ifconifg命令,没有显示ens33

Linux使用ifconifg命令,没有显示ens33 1.问题2.步骤2.1 查看虚拟机的组件是否启动了2.2 修改网络配置文件 ONBOOT修改为yes2.3 重启网络2.4 修改网络服务配置 3.解决 1.问题 打开虚拟机准备使用xshell连接时发现连接失败,在机器上查看ip发现ens33不现实…

Mysql主从搭建

Mysql主从搭建 1.Mysql下载1.1 查看操作系统2.2 下载mysql安装包 2.Mysql安装2.1 解压2.2 目录重命名2.3 创建data,存储文件2.4 创建用户组2.5 授权用户2.6 配置环境变量2.7 编辑my.cnf2.8 创建相关目录和文件2.9 初始化数据库2.10 复制mysql.server到/etc/init.d/下…

安卓环境搭建及运行安卓应用

1 jdk安装 安卓项目也是java开发的,运行在虚拟器上,安装jdk及运行的时候,就会带上虚拟器 jdk前面已经讲过,不在讲解 2 下载安装androj studio https://developer.android.google.cn/studio?hlzh-cn 下载下来,双击…

Shell脚本:Linux Shell脚本学习指南(第一部分Shell基础)一

你好,欢迎来到「Linux Shell脚本」学习专题,你将享受到免费的 Shell 编程资料,以及很棒的浏览体验。 这套 Shell 脚本学习指南针对初学者编写,它通俗易懂,深入浅出,不仅讲解了基本知识,还深入底…

数据采集与大数据架构分享

实现场景 要实现亿级数据的长期收集更新,并对采集后的数据进行整理和加工,用于人工智能的训练数据素材集。 数据采集 java支持的爬虫框架还是有很多的,如:webMagic、Spider、Jsoup等添加链接描述 pipeline处理管道 数据并发开发…

供应链|顶刊MSOM论文解读:服务竞争下的库存共享

问题背景 在汽车、玩具等行业中,零售商之间的库存共享变得十分常见。库存共享可以解决由需求不确定导致的库存错配问题。如果零售商之间同意共享库存,那么当需求较少、自身库存过剩时,可以将过剩库存卖给其他零售商;反之&#xf…

WinForms C# 导入和导出 CSV 文件 Spread.NET

使用 WinForms C# 和 VB.NET 导入和导出 CSV 文件 2023 年 11 月 17 日 使用 Spread.NET 直接在 .NET WinForms 应用程序中处理 CSV 文件。 Spread.NET可帮助您创建电子表格、网格、仪表板和表单。它包括一个强大的计算引擎,具有 450 多个函数以及导入和导出 Micros…

iptables详解:链、表、表链关系、规则的基本使用

目录 防火墙基本概念 什么是防火墙? Netfilter与iptables的关系 链的概念 表的概念 表链关系 规则的概念 查询规则 添加规则 删除iptables中的记录 修改规则 更详细的命令(5链4表) 防火墙基本概念 什么是防火墙? 在…

在VS Code中使用VIM

文章目录 安装和基本使用设置 安装和基本使用 VIM是VS Code的强大对手,其简化版本VI是Linux内置的文本编辑器,堪称VS Code问世之前最流行的编辑器,也是VS Code问世之后,我仍在使用的编辑器。 对VIM无法割舍的原因有二&#xff0…

Python中,我们可以使用pandas和numpy库对Excel数据进行预处理,包括读取数据、数据清洗、异常值剔除等

文章目录 一、什么是数据预处理二、对excel数据进行详细的数据预处理操作总结 一、什么是数据预处理 数据预处理是一种对数据进行清洗、整理、转换等操作的过程,旨在提高数据质量,使其适应模型的需求,从而改进数据挖掘或机器学习的结果。 数…
最新文章