当前位置: 首页 > news >正文

Transformer实战-系列教程4:Vision Transformer 源码解读2

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

4、Embbeding类

self.embeddings = Embeddings(config, img_size=img_size)
class Embeddings(nn.Module):"""Construct the embeddings from patch, position embeddings."""def __init__(self, config, img_size, in_channels=3):super(Embeddings, self).__init__()self.hybrid = Noneimg_size = _pair(img_size)if config.patches.get("grid") is not None:grid_size = config.patches["grid"]patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])n_patches = (img_size[0] // 16) * (img_size[1] // 16)self.hybrid = Trueelse:patch_size = _pair(config.patches["size"])n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])self.hybrid = Falseif self.hybrid:self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,width_factor=config.resnet.width_factor)in_channels = self.hybrid_model.width * 16self.patch_embeddings = Conv2d(in_channels=in_channels,out_channels=config.hidden_size,kernel_size=patch_size,stride=patch_size)self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))self.dropout = Dropout(config.transformer["dropout_rate"])def forward(self, x):# print(x.shape)B = x.shape[0]cls_tokens = self.cls_token.expand(B, -1, -1)# print(cls_tokens.shape)if self.hybrid:x = self.hybrid_model(x)x = self.patch_embeddings(x)#Conv2d: Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))# print(x.shape)x = x.flatten(2)# print(x.shape)x = x.transpose(-1, -2)# print(x.shape)x = torch.cat((cls_tokens, x), dim=1)# print(x.shape)embeddings = x + self.position_embeddings# print(embeddings.shape)embeddings = self.dropout(embeddings)# print(embeddings.shape)return embeddings

接上前面的debug模式,在构造模型部分一直步入到Embbeding类中:

  1. 构造函数,传入了图像大小224*224,通道数3,以及配置参数
  2. patch_size=[16,16],16*16的区域选出一份特征,这个参数自己定义
  3. n_patches,224224的图像能够切分出1616的格子数量,(224/16)(224/16)=1414=196个
  4. 196就是我们要定义的序列的长度了
  5. patch_embeddings,是一个二维卷积,输入通道为3,输出通道为768,卷积核为patch_size=1616,步长为1616,步长为1616就表明原本224224的图像卷积后的长宽就为14*14了
  6. position_embeddings,初始化参数全部为0 ,形状为[1,197,768],197=196+1,加一的原因是在Transformer模型中,通常会在序列的开始添加一个可学习的类标记(class token),它在训练过程中帮助模型捕获全局信息以用于分类任务。position_embeddings是用来记录位置信息的
  7. cls_token,初始化参数全部为0,形状为[1,1,768]
  8. 因为要涉及到全连接层,所以加上Dropout

5、Encoder类

self.encoder = Encoder(config, vis)
class Encoder(nn.Module):def __init__(self, config, vis):super(Encoder, self).__init__()self.vis = visself.layer = nn.ModuleList()self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)for _ in range(config.transformer["num_layers"]):layer = Block(config, vis)self.layer.append(copy.deepcopy(layer))def forward(self, hidden_states):# print(hidden_states.shape)attn_weights = []for layer_block in self.layer:hidden_states, weights = layer_block(hidden_states)if self.vis:attn_weights.append(weights)encoded = self.encoder_norm(hidden_states)return encoded, attn_weights

接上前面的debug模式,在构造模型部分步入到Encoder类中:

  1. 构造函数传进配置参数
  2. vis,设置可视化
  3. layer,设置PyTorch的一个列表
  4. encoder_norm,LayerNorm,Batch Normalization是对Batch做归一化,LayerNorm对层
  5. 循环添加Block:循环config.transformer["num_layers"]次,每次都创建一个Block实例并添加到self.layer中。这里的Block是一个定义了Transformer编码器层的类,它包括自注意力机制和前馈网络。copy.deepcopy(layer)确保每次都是向ModuleList添加一个新的、独立的Block副本

之前ConvNet的任务中,都是使用Batch 做归一化,为什么Transformer是对Layer做归一化呢,Transformer是在NLP任务中提出来的,每一句话的单词个数都不一样,太长的阶段,短的补0,如果是对batch做归一化,长句子的后面一些地方要和短句子补0的地方做归一化,改用Layer归一化实现显著提升效果的情况。

相关文章:

Transformer实战-系列教程4:Vision Transformer 源码解读2

🚩🚩🚩Transformer实战-系列教程总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Pycharm中进行 本篇文章配套的代码资源已经上传 4、Embbeding类 self.embeddings Embeddings(config, img_sizeimg_size) class Embeddings(nn.…...

cesium-水平测距

cesium测量两点间的距离 <template><div id"cesiumContainer" style"height: 100vh;"></div><div id"toolbar" style"position: fixed;top:20px;left:220px;"><el-breadcrumb><el-breadcrumb-item&…...

【Android-Compose】手势检测实现按下、单击、双击、长按事件,以及避免频繁单击事件的简单方法

目录&#xff1a; 1 不需要双击事件 规避频繁单击事件2 需要双击事件&#xff08;常规写法&#xff09;3 后记&#xff1a;不建议使用上面的代码自定义按钮 1 不需要双击事件 规避频繁单击事件 var firstClickTime by remember { mutableStateOf(System.currentTimeMillis()…...

AUTOSAR汽车电子嵌入式编程精讲300篇-基于神经网络的CAN总线负载率优化(续)

目录 3.3 SA 算法 3.3.1 SA 算法原理 3.3.2 基于 SA 算法 CAN 总线负载率优化分析...

python爬虫6—高性能异步爬虫

如果有多个URL等待我们爬取&#xff0c;我们通常是一次只能爬取一个&#xff0c;爬取效率低&#xff0c;异步爬虫可以提高爬取效率&#xff0c;可以一次多多个URL同时同时发起请求 异步爬虫方式&#xff1a; 一、多线程、多进程&#xff08;不建议&#xff09;&#xff1a;可以…...

日历功能——C语言

实现日历功能&#xff0c;输入年份月份&#xff0c;输出日历 #include<stdio.h>int leap_year(int year) {if(year % 4 0 && year % 100 ! 0 || year % 400 0){return 1;}else{return 0;} }int determine_year_month_day(int *day,int month,int year) {if(mo…...

GPIO中断

1.EXTI简介 EXTI是External Interrupt的缩写&#xff0c;指外部中断。在嵌入式系统中&#xff0c;外部中断是一种用于处理外部事件的机制。当外部事件发生时&#xff08;比如按下按钮、传感器信号变化等&#xff09;&#xff0c;外部中断可以立即打断正在执行的程序&#xff0…...

springboot完成一个线上图片存放地址+实现前后端上传图片+回显

1.路径 注意路径 2.代码&#xff1a;&#xff08;那个imagePath没什么用&#xff0c;懒的删了&#xff09;&#xff0c;注意你的本地文件夹要有图片&#xff0c;才可以在线上地址中打开查看 package com.xxx.common.config;import org.springframework.beans.factory.annotat…...

编程思维与生活琐事的内在关联及其应用价值

随着科技的日益普及和信息化时代的到来&#xff0c;编程作为一种现代技能&#xff0c;其影响已不再局限于专业领域&#xff0c;而是逐步渗透到人们的日常生活之中。探讨编程与生活琐事之间的关系&#xff0c;有助于我们更好地理解如何将技术智慧应用于日常管理&#xff0c;提升…...

OSPF排错

目录 实验拓扑图 实验要求 实验排错 故障一 故障现象 故障分析 故障解决 故障二 故障现象 故障分析 故障解决 故障三 故障现象 故障分析 故障解决 故障四 故障现象 故障分析 故障解决 故障五 故障现象 故障分析 故障解决 故障六 故障现象 故障分析 …...

day07-CSS高级

01-定位 作用&#xff1a;灵活的改变盒子在网页中的位置 实现&#xff1a; 1.定位模式&#xff1a;position 2.边偏移&#xff1a;设置盒子的位置 left right top bottom 相对定位 position: relative 特点&#xff1a; 不脱标&#xff0c;占用自己原来位置 显示模…...

05 MP之ActiveRecord模式+SimpleQuery

1. ActiveRecord ActiveRecord(活动记录&#xff0c;简称AR)&#xff0c;是一种领域模型模式&#xff0c;特点是一个模型类对应关系型数据库中的一个表&#xff0c;而模型类的一个实例对应表中的一行记录。 其目标是通过围绕一个数据对象, 进行全部的CRUD操作。 1.1 让实体类…...

git diff查看比对两次不同时间点提交的异同

git diff查看比对两次不同时间点提交的异同 用 git diff命令&#xff1a; git diff commit-id-1 commit-id-2 不同commit-id在不同的时间点提交产生&#xff0c;因为也可以认为git diff是比对两个不同时间点的代码异同。 git diff比较不同commit版本的代码文件异同_git diff c…...

基于muduo网络库开发服务器程序和CMake构建项目 笔记

跟着施磊老师做C项目&#xff0c;施磊老师_腾讯课堂 (qq.com) 一、基于muduo网络库开发服务器程序 组合TcpServer对象创建EventLoop事件循环对象的指针明确TcpServer构造函数需要什么参数,输出ChatServer的构造函数在当前服务器类的构造函数当中,注册处理连接的回调函数和处理…...

前端支持下载模板、导入数据、导出数据(excel格式)

前言 xlsx是由SheetJS开发的一个处理excel文件的npm库,适用于前端开发者实现下载模板、导入导出excel文件等需求&#xff0c;演示的项目的技术栈为vue3 elementPlus 一. 引入xlsx 安装xlsx npm install xlsx引入xlsx import * as XLSX from xlsx;二. 下载模板 const han…...

编译Faiss-gpu【InterMKL】C++ 按步骤操作 基本不会有问题的 python原理相同。

编译Faiss-gpu C++ 基本介绍 使用Faiss版本【1.7.4】 该项目依赖于BLAS 组件 OpenBLAS 和 IntelMKL BLAS 【官方支持】 IntelMKL 会比 OpenBLAS 快的多。 【来自官方结论】 本机环境 Cuda :11.1 Cuda-Driver: 515 InterMKL: 2021.2.0 Faiss :1.7.4 注意:faiss仅…...

conn.execute的用法详解

conn.execute的用法详解 大家好&#xff0c;我是免费搭建查券返利机器人赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01;今天&#xff0c;我们将深入研究数据库连接中conn.execute的用法&#xff0c;解析它的功能、…...

GetBuffer() 与 ReleaseBuffer() 使用详解

GetBuffer() 与 ReleaseBuffer() 使用详解 大家好&#xff0c;我是免费搭建查券返利机器人赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01;今天&#xff0c;我们将深入研究在编程中常用的GetBuffer()与ReleaseBuff…...

Flink CEP(基本概念)

Flink CEP 在Flink的学习过程中&#xff0c;我们已经掌握了从基本原理和核心层的DataStream API到底层的处理函数&#xff0c;再到应用层的Table API和SQL的各种手段&#xff0c;可以应对实际应用开发的各种需求。然而&#xff0c;在实际应用中&#xff0c;还有一类更为复…...

[AIGC] Spring Gateway与 nacos 简介

文章目录 Spring Gateway简介主要特性优点总结 Nacos简介主要特性优点总结 Spring Gateway 简介 Spring Gateway是一个基于Spring Framework的工具&#xff0c;用于构建和管理微服务架构中的网关。它提供了一种简单而灵活的方式来路由和过滤请求&#xff0c;以及在微服务之间…...

接口不是json的内容能用Jsonpath获取吗,如果不能,我们选用什么方法处理呢?

JsonPath 是一种专门用于查询和提取 JSON 数据的查询语言&#xff08;类似 XPath 用于 XML&#xff09;。以下是详细解答&#xff1a; ​JsonPath 的应用场景​ ​API 响应处理​&#xff1a;从 REST API 返回的 JSON 数据中提取特定字段。​配置文件解析​&#xff1a;读取 J…...

使用 Windows 完成 iOS 应用上架:Appuploader对比其他证书与上传方案

iOS 应用上架流程对很多开发者来说都是一道复杂关卡&#xff0c;特别是当你并不使用 Mac 电脑时。虽然 Apple 一直强调使用其原生工具链&#xff08;Xcode 和 Transporter&#xff09;&#xff0c;但现实是大量开发者正在寻找更灵活的替代方案。 今天我将从证书申请和 IPA 上传…...

Edge(Bing)自动领积分脚本部署——基于python和Selenium(附源码)

微软的 Microsoft Rewards 计划可以通过 Bing 搜索赚取积分&#xff0c;积分可以兑换礼品卡、游戏等。每天的搜索任务不多&#xff0c;我们可以用脚本自动完成&#xff0c;提高效率&#xff0c;解放双手。 本文将手把手教你如何部署一个自动刷积分脚本&#xff0c;并解释其背…...

StringRedisTemplete使用

StringRedisTemplate是Spring Data Redis提供的一个模板类&#xff0c;用于简化对Redis的操作。它特别适合处理字符串类型的数据&#xff0c;并且封装了一系列常用的Redis命令&#xff0c;使开发者能够以更简洁的方式进行Redis操作。本文将详细介绍 StringRedisTemplate的使用方…...

Elasticsearch 海量数据写入与高效文本检索实践指南

Elasticsearch 海量数据写入与高效文本检索实践指南 一、引言 在大数据时代&#xff0c;企业和组织面临着海量数据的存储与检索需求。Elasticsearch&#xff08;以下简称 ES&#xff09;作为一款基于 Lucene 的分布式搜索和分析引擎&#xff0c;凭借其高可扩展性、实时搜索和…...

《数据挖掘》- 房价数据分析

这里写目录标题 采用的技术1. Python编程语言2. 网络爬虫库技术点对比与区别项目技术栈的协同工作流程 代码解析1. 导入头文件2. 读取原始数据3. 清洗数据4. 数据分割4.1 统计房屋信息的分段数量4.2 将房屋信息拆分为独立列4.3 处理面积字段4.4 删除原始房屋信息列 5. 可视化分…...

C++之动态数组vector

Vector 一、什么是 std::vector&#xff1f;二、std::vector 的基本特性&#xff08;一&#xff09;动态扩展&#xff08;二&#xff09;随机访问&#xff08;三&#xff09;内存管理 三、std::vector 的基本操作&#xff08;一&#xff09;定义和初始化&#xff08;二&#xf…...

[论文阅读] 人工智能 | 大语言模型计划生成的新范式:基于过程挖掘的技能学习

#论文阅读# 大语言模型计划生成的新范式&#xff1a;基于过程挖掘的技能学习 论文信息 Skill Learning Using Process Mining for Large Language Model Plan Generation Andrei Cosmin Redis, Mohammadreza Fani Sani, Bahram Zarrin, Andrea Burattin Cite as: arXiv:2410.…...

32单片机——窗口看门狗

1、WWDG的简介 WWDG&#xff1a;Window watchdog&#xff0c;即窗口看门狗 窗口看门狗本质上是能产生系统复位信号和提前唤醒中断的递减计数器 WWDG产生复位信号的条件&#xff1a; &#xff08;1&#xff09;当递减计数器值从0x40减到0x3F时复位&#xff08;即T6位跳变到0&a…...

VR视频制作有哪些流程?

VR视频制作流程知识 VR视频制作&#xff0c;作为融合了创意与技术的复杂制作过程&#xff0c;涵盖从初步策划到最终呈现的多个环节。在这个过程中&#xff0c;我们可以结合众趣科技的产品&#xff0c;解析每一环节的实现与优化&#xff0c;揭示背后的奥秘。 VR视频制作有哪些…...