DETR整体模型结构解析

news/2024/6/20 15:53:55/文章来源:https://blog.csdn.net/weixin_43912994/article/details/139302553

DETR流程

  1. Backbone用卷积神经网络抽特征。最后通过一层1*1卷积转化到d_model维度fm(B,d_model,HW)。

  2. position embedding建立跟fm维度相同的位置编码(B,d_model,HW)。

  3. Transformer Encoder,V为fm,K,Q为fm+position embedding。因为V代表的是图像特征。所以不添加位置编码

  4. Transformer Decoder。生成一个固定大小(query_num)的object query(B,q_num,d_model)比如100个预测框。Decoder输入tgt与object query形状相同。代码中为torch.zero()。第一层selfattention K,V为tgt+query,Q为tgt。第二层Q为上一层输出+query。V为encoder输出,K为encoder输出+position。这里V仍然代表图像特征所以不添加位置编码

  5. 用输出的100个object query框和ground truth框做一个匹配,然后在一一配对好的框中去计算目标检测的loss(分类loss与回归loss(L1+IOU))

  6. 二分图匹配与匈牙利算法

    DETR 预测了一组固定大小的 N = 100 个边界框

    将 ground-truth 也扩展成 N = 100 个检测框

    使用一个额外的特殊类标签 ϕ 来表示在未检测到任何对象,或者认为是背景类别。

    这样预测和真实都是两个100 个元素的集合了

    采用匈牙利算法进行二分图匹配,对预测集合和真实集合的元素进行一一对应,使得匹配损失最小。

  7. 推理过程不需要二分图匹配,只需要取最大得分框即可

代码详细参考:

transformer 在 CV 中的应用(二) DETR 目标检测网络 -

网络结构

参数说明:B:batchsize大小,C通道数,H,W:CNN输出特征图的高宽。d_model设定的特征维度大小如512。
Q,K,V:自注意力矩阵。l_q:Q矩阵的长度,l_kv:K,V矩阵的长度。KV矩阵的长度必须相同,Q矩阵长度可以跟KV矩阵长度不同
Q矩阵维度:(B,l_q,d_model)
K矩阵维度:(B,l_kv,d_model)
V矩阵维度:(B,l_kv,d_model)
object_query维度(B,q_num,d_model)

Backbone:

img→CNNbackbone→fm特征图(B,C,H,W) → fm特征图输入到transformer中时要再经过一层卷积将通道数转化成d_ model。C→d_model.

position embedding(B,d_model,H*W)。backbone通过CNN提取图像特征,然后通过特征图生成尺度对应的位置编码。

position embedding:

位置编码官方实现了两种,一种是固定位置编码,另一种是自学习位置编码,这里就介绍固定位置编码。

位置编码要考虑 x, y 两个方向,图像中任意一个点 (h, w) 有一个位置,这个位置编码长度为 256 ,前 128 维代表 h 的位置编码, 后 128 维代表 w 的位置编码,把这两个 128 维的向量拼接起来就得到一个 256 维的向量,它代表 (h, w) 的位置编码。位置编码的计算公式如下图所示

在这里插入图片描述

Transformer
DETRtransformer结构图
在这里插入图片描述

接受CNN提取的特征(B,d_model,HW),位置编码(B,d_model,HW),querys(B,query_num,d_model)

encoder:q,k添加位置编码。v代表图像本身特征,不添加位置编码。multi_head_attention跟FFN后都带了两个残差连接。

# post代表norm放在后面
def forward_post(self,src,src_mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None):q = k = self.with_pos_embed(src, pos)  #q,k增加positionsrc2 = self.self_attn(q, k, value=src, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]src = src + self.dropout1(src2)   # 残差src = self.norm1(src)# ffnsrc2 = self.linear2(self.dropout(self.activation(self.linear1(src))))# 残差src = src + self.dropout2(src2)src = self.norm2(src)return src

decoder:

设定一个object queries(num_query,d_model)

有两层multihead self attention

  • 第一层obquery添加到K,Q上作为position embedding

第二层的Q来于decoder,K,V来自于encoder输出。

第二层self attention K添加编码,Q增加object queries。V代表图像特征,不添加任何信息

KV要有相同维度,Q可以跟KV在长度维度上不同,d_model维度相同

softmax(QKt/(d^0.5))V→矩阵乘法Q*Kt:(l_q,d_model)@(d_model_l_kv)→(l_q,l_kv)

再乘以V(l_q,l_kv)@(l_kv,d_model)→(l_q,d_model)

def forward_post(self, tgt, memory,tgt_mask: Optional[Tensor] = None,memory_mask: Optional[Tensor] = None,tgt_key_padding_mask: Optional[Tensor] = None,memory_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None,query_pos: Optional[Tensor] = None):''':param tgt: query_pos rep 2tensor of shape (bs, c, h, w) ->tgt = torch.zeros_like(query_embed):param memory::param tgt_mask::param memory_mask::param tgt_key_padding_mask::param memory_key_padding_mask::param pos::param query_pos::return:'''q = k = self.with_pos_embed(tgt, query_pos)tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,key_padding_mask=tgt_key_padding_mask)[0]tgt = tgt + self.dropout1(tgt2)tgt = self.norm1(tgt)tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),key=self.with_pos_embed(memory, pos),value=memory, attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0]tgt = tgt + self.dropout2(tgt2)tgt = self.norm2(tgt)tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))tgt = tgt + self.dropout3(tgt2)tgt = self.norm3(tgt)return tgt
  • DETR在计算attention的时候没有使用masked attention,因为将特征图展开成一维以后,所有像素都可能是互相关联的,因此没必要规定mask。

  • object queries的转换过程:object queries是预定义的目标查询的个数,代码中默认为100。它的意义是:根据Encoder编码的特征,Decoder将100个查询转化成100个目标,即最终预测这100个目标的类别和bbox位置。最终预测得到的shape应该为[N, 100, C],N为Batch Num,100个目标,C为预测的100个目标的类别数+1(背景类)以及bbox位置(4个值)

得到预测结果以后,将object predictions和ground truth box之间通过匈牙利算法进行二分匹配:假如有K个目标,那么100个object predictions中就会有K个能够匹配到这K个ground truth,其他的都会和“no object”匹配成功,使其在理论上每个object query都有唯一匹配的目标,不会存在重叠,所以DETR不需要nms进行后处理。

匹配

匈牙利匹配算法

匈牙利匹配算法,二分图匹配算法

scipy.optimize.linear_sum_assignment(cost_matrix, maximize=False)
#cost_matrix 二分图开销矩阵

https://blog.csdn.net/CV_Autobot/article/details/129096035

https://blog.csdn.net/lemonxiaoxiao/article/details/108672039

query与gt匹配

transformer通过query输出n_q数量的bbox与对应分类置信度

真实框[gt1,gt2,…gtn]

每个bbox与gt之间有一个距离度量。

距离度量由三部分组成:真实类别的置信度得分+边界框的L1loss+边界框的IOUloss

通过匈牙利算法找出距离最小的query_bbox为gt对应的prebbox

loss训练

整体流程

pred输出→100(num_queries)class,100(num_queries)boxes

gt(tagert)→100class,100boxes(包含背景类)

pred,gt→计算相互loss,得到二分图成本矩阵,然后计算匈牙利匹配算法→return匹配上的classes与boxes

匹配成功的框→计算真正的class损失(),box回归损失(GLOUloss)。。

预测框与真实框的差异来自于两方面:1.二分图匹配时带来的差异。2.预测框与真实框之间的差异。

  • 分类损失:交叉熵损失,针对所有predictions。没有匹配到的querybbox应该分类为背景

  • 回归损失:bbox loss采用了L1 loss和giou loss,针对匹配成功的querybbox

  • cardinality 损失,对应函数是 loss_cardinality ; cardinality 损失是计算预测有物体的个数的绝对损失,值是为了记录,不参与反向传播

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ldbm.cn/p/430505.html

如若内容造成侵权/违法违规/事实不符,请联系编程新知网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开源工具专题-04 Atlassian Crowd部署备份及迁移

开源工具专题-04 Atlassian Crowd部署备份及迁移 注: 本教程由羞涩梦整理同步发布,本人技术分享站点:blog.hukanfa.com转发本文请备注原文链接,本文内容整理日期:2024-05-29csdn 博客名称:五维空间-影子&…

PyMySQL连接池

背景 在用python写后端服务时候,需要与mysql数据库进行一些数据查询或者插入更新等操作。启动服务后接口运行一切正常, 隔了第二天去看服务日志就会报错,问题如下: pymysql.err.OperationalError: (2006, "MySQL server ha…

Vue——事件修饰符

文章目录 前言阻止默认事件 prevent阻止事件冒泡 stop 前言 在官方文档中对于事件修饰符有一个很好的说明,本篇文章主要记录验证测试的案例。 官方文档 事件修饰符 阻止默认事件 prevent 在js原生的语言中,可以根据标签本身的事件对象进行阻止默认事件…

Java(六)——抽象类与接口

文章目录 抽象类和接口抽象类抽象类的概念抽象类的语法抽象类的特性抽象类的意义 接口接口的概念接口的语法接口的特性接口的使用实现多个接口接口与多态接口间的继承抽象类和接口的区别 抽象类和接口 抽象类 抽象类的概念 Java使用类实例化对象来描述现实生活中的实体&…

算法与数据结构高手养成:朴素的贪心法(上)最优化策略

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨ 🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢,在这里我会分享我的知识和经验。&am…

SOL 交易机器人基本知识

有没有可以盈利的机器人? 是的,各行各业都有许多盈利机器人。在金融领域,交易机器人被广泛用于自动化投资策略并根据预定义的算法执行交易。这些机器人可以分析市场趋势并做出快速决策,从而可能带来可观的回报。同样,在…

基于广义极大极小凹惩罚的心电信号降噪方法(MATLAB R2021B)

凸优化是数学最优化的一个子领域,研究定义于凸集中的凸函数最小化问题。由于心电信号降噪的过程可以理解为求信号的稀疏近似解,因此基于凸优化和稀疏性表达的去噪方法可用于心电信号处理。在凸优化的数学模型中,惩罚项的选取对最终结果会产生…

04_前端三大件JS

文章目录 JavaScript1.JS的组成部分2.JS引入2.1 直接在head中通过一对script标签定义脚本代码2.2创建JS函数池文件,所有html文件共享调用 3.JS的数据类型和运算符4.分支结构5.循环结构6.JS函数的声明7.JS中自定义对象8.JS_JSON在客户端使用8.1JSON串格式8.2JSON在前…

(文章复现)分布式电源接入配电网承载力评估方法研究

参考文献: [1]郝文斌,孟志高,张勇,等.新型电力系统下多分布式电源接入配电网承载力评估方法研究[J].电力系统保护与控制,2023,51(14):23-33. 1.摘要 随着光伏和风电等多种分布式电源的接入,使得传统配电网的结构及其运行状态发生了较大改变。因此&…

Pi 母公司将开发情感 AI 商业机器人;Meta 科学家:Sora 不是视频生成唯一方向丨RTE 开发者日报 Vol.214

开发者朋友们大家好: 这里是 「RTE 开发者日报」 ,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享 RTE(Real-Time Engagement) 领域内「有话题的新闻」、「有态度的观点」、「有意思的数据」、「有思考的文章」、「…

DiffBIR论文阅读笔记

这篇是董超老师通讯作者的一篇盲图像修复的论文,目前好像没看到发表在哪个会议期刊,应该是还在投,这个是arxiv版本,代码倒是开源了。本文所指的BIR并不是一个single模型对任何未知图像degradation都能处理,而是用同一个…

电脑显示由于找不到msvcr110.dll 无法继续执行如何处理?最简单的修复msvcr110.dll文件方法

电脑显示由于找不到msvcr110.dll 无法继续执行?当你看到这种提示的时候,请不要紧张,这种是属于dll文件丢失,解决起来还是比较简单的,下面会详细的列明多种找不到msvcr110.dll的解决方法。 一.找不到msvcr110.dll是怎么…

FPGA——eMMC验证

一.FPGA基础 1.FPGA烧录流程 (1) 加载流文件 —— bitfile (2) 烧录文件 —— cmm 二.MMC 1.基础知识 (1)jz4740、mmc、emmc、sd之间的关系? jz4740——处理器 mmc——存储卡标准 emmc——mmc基础上发展的高效存储解决方案 sd—— 三.eMMC和SD case验证 1.ca…

【网络安全】新的恶意软件:无文件恶意软件GhostHook正在广泛传播

文章目录 推荐阅读 一种新的恶意软件 GhostHook v1.0 正在一个网络犯罪论坛上迅速传播。这种创新的无文件浏览器恶意软件由 Native-One 黑客组织开发,具有独特的分发方式和多功能性,对各种平台和浏览器构成重大威胁。 GhostHook v1.0 支持 Windows、Andr…

【Unity AR开发插件】五、运行示例程序

专栏 本专栏将介绍如何使用这个支持热更的AR开发插件,快速地开发AR应用。 链接: Unity开发AR系列 热更数据制作:制作热更数据-AR图片识别场景 插件简介 通过热更技术实现动态地加载AR场景,简化了AR开发流程,让用户可…

这个橙子真的香!老司机徒手把玩香橙派Kunpeng Pro事后感言

水果派爱好者 作为一个水果派爱好者,我早在大学时期就购入了树莓派3B以及一堆传感器配件,把玩至今。 后来在工作的时候,由于不满足树莓派3B的性能,又购入了树莓派4B,玩得不亦乐乎,做出了很多因吹斯汀的小玩…

Python装饰器的应用

Python 中的装饰器是一种语法糖,可以在运行时,动态的给函数或类添加功能。装饰器本质上是一个函数,使用 函数名就是可实现绑定给函数的第二个功能 。它的作用就是在不修改被装饰对象源代码和调用方式的前提下为被装饰对象添加额外的功能。 …

微前端(无界)入门

主应用通过props给子应用传值 父子应用通过eventBus通信 通过路由同步实现记录子应用的路由状态 主应用 main.ts: import ./assets/main.cssimport { createApp } from vue import { createPinia } from pinia import WujieVue from wujie-vue3import App from ./App.vue impo…

【开源三方库】Aki:一行代码极简体验JSC++跨语言交互

一、简介 OpenAtom OpenHarmony(以下简称“OpenHarmony”)的前端开发语言是ArkTS,在TypeScript(简称TS)生态基础上做了进一步扩展,继承了TS的所有特性,是JavaScript(简称JS&#xf…

【好书分享第十三期】AI数据处理实战108招:ChatGPT+Excel+VBA

文章目录 一、内容介绍二、内页插图三、作者简介四、前言/序言五、目录 一、内容介绍 《AI数据处理实战108招:ChatGPTExcelVBA》通过7个专题内容、108个实用技巧,讲解了如何运用ChatGPT结合办公软件Excel和VBA代码实现AI办公智能化、高效化。随书附赠了…