当前位置: 首页 > article >正文

手搓多模态-03 顶层和嵌入层的搭建

声明:本代码非原创,是博主跟着国外大佬的视频教程编写的,本博客主要为记录学习成果所用。

我们首先开始编写视觉模型这一部分这一部分主要功能接收一个batch图像将其转化上下文相关嵌入向量这一阶段我们需要做的事情以下这些

  • 编写一个全局通用视觉配置
  • 编写用户模型调用
  • 输入图像嵌入向量
  • 通过transformer编码图像嵌入向量进行编码使其上下文相关

我们一个个实现这些代码

视觉模型配置

视觉模型配置主要如下

class SiglipVisionConfig:def __init__(
			self,
			hidden_size=768,
			num_hidden_layers=12,
			num_attention_heads=12,
			intermediate_size=3072,
			num_channels=3,
			image_size=224,
			patch_size=16,
			attention_dropout=0.0,
			layer_norm_eps=1e-6,
			num_image_tokens: int = None,**kwargs):super().__init__(**kwargs)
		self.hidden_size = hidden_size ## embedding 的维度
		self.num_hidden_layers = num_hidden_layers ## 隐藏层的数量
		self.num_attention_heads = num_attention_heads ## 注意力头数量
		self.intermediate_size = intermediate_size ## 线性层的维度 
		self.num_channels = num_channels ##图像的RGB通道
		self.image_size = image_size ## 图像尺寸,任何图像size都会被缩放到这个尺寸
		self.patch_size = patch_size ## 每个patch的尺寸
		self.attention_dropout = attention_dropout ## 注意力层dropout
		self.layer_norm_eps = layer_norm_eps ## 层归一化epsilon
		self.num_image_tokens = num_image_tokens ## 图像token数量,它实际上是一个固定值

为了各个变量涵义更加浅显易懂博主增加中文注释

顶层模型

随后用户模型用户只需要一个batch图像传入调用forward函数即可返回这些图像上下文相关embeddings

代码如下

class SiglipVisionModel(nn.Module): ## 最顶层的视觉模型,它负责顶层的传入和编码的输出def __init__(self, config:SiglipVisionConfig):super().__init__()self.config = configself.vision_model = SiglipVisionTransformer(config)def forward(self, pixel_values) -> Tuple:# [Batch_size,Channels,Height,Width] ->  [Batch_size,Num_Patches(num_image_token),Embedding_size(Hidden_size)]return self.vision_model(pixel_values = pixel_values)	

其中输入形状 [ Batch_size, Channels, Height, Width ],对应一个图像batch各自RGB通道像素输出 [ Batch_size, Num_Patches, Embedding_size ], 对应各个图像每个分割Patch嵌入结果

模型拆分

我们需要内部一次视觉模型调用拆分两个模型各自的调用也就是拆分嵌入模型Transformer编码这里我们创建一个SiglipVisionTransformer将其分成两个模型调用

代码如下

class SiglipVisionTransformer(nn.Module): ##视觉模型的第二层,将模型的调用分为了图像嵌入模型和transformer编码器模型的调用def __init__(self, config:SiglipVisionConfig):super().__init__()self.config = configself.embed_dim = config.hidden_sizeself.embeddings = SiglipVisionEmbeddings(config) ## 负责将图像嵌入成向量self.encoder = SiglipEncoder(config) ## 负责将向量编码成注意力相关的向量self.post_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) ## 层归一化def forward(self, pixel_values:torch.Tensor) -> torch.Tensor:"""
		pixel_values: [Batch_size,Channels,Height,Width]"""## [ Batch_size,Channels,Height,Width] -> [Batch_size,Num_Patches,Embedding_size] 
		hidden_states = self.embeddings(pixel_values) ## 将图像嵌入成向量# [Batch_size,Num_Patches,Embedding_size] -> [Batch_size,Num_Patches,Embedding_size]
		last_hidden_state = self.encoder(hidden_states) ## 将向量编码成注意力相关的向量# [Batch_size,Num_Patches,Embedding_size] -> [Batch_size,Num_Patches,Embedding_size]
		last_hidden_state = self.post_layer_norm(last_hidden_state)return last_hidden_state

嵌入模型

嵌入模型初始图像像素初步转换patch编码向量list, 同时阶段我们使用位置编码位置编码形式多种这里我们采用自学习嵌入向量每个位置创建一个可以学习参数向量形成位置矩阵使用时候根据indices位置矩阵抽取对应位置向量即可

代码

class SiglipVisionEmbeddings(nn.Module):	def __init__(self, config:SiglipVisionConfig):
		self.config = config
		self.patch_size = config.patch_size
		self.image_size = config.image_size
		self.embed_dim = config.hidden_size		self.patch_embedding = nn.Conv2d(
			in_channels = config.num_channels,
			out_channels = self.embed_dim,
			kernel_size = self.patch_size,
			stride = self.patch_size,
			padding = 'valid', ##不加padding)		self.num_patches = (self.image_size // self.patch_size) ** 2 ## 图像的patch数量 (224 // 16) ** 2 = 196
		self.num_positions = self.num_patches		self.position_embeddings = nn.Embedding(self.num_positions, self.embed_dim)		self.register_buffer("position_ids",
			torch.arange(self.num_positions).expand((1, -1)), ## 这里expand是为了保持和patch_embeds的维度一致,以便可以直接与之相加
			persistent=False,)	def forward(self, pixel_values:torch.FloatTensor) -> torch.Tensor:"""
		pixel_values: [Batch_size,Channels,Height,Width]
		"""
		_ , _ , height, width = pixel_values.shape## 卷积,3通道转embedding_size通道
		patch_embeds = torch.FloatTensor(self.patch_embedding(pixel_values)) ## [Batch_size,Channel,Height,Width] -> [Batch_size,Embedding_size,Num_Patches_Height,Num_Patches_Width]## flatten
		patch_embeds = patch_embeds.flatten(2) # [Batch_size,Embedding_size,Num_Patches_Height,Num_Patches_Width] -> [Batch_size,Embedding_size,Num_Patches]## transpose
		patch_embeds = patch_embeds.transpose(1,2) ## [Batch_size,Embedding_size,Num_Patches] -> [Batch_size,Num_Patches,Embedding_size] ## positon_encoding
		patch_embeds = patch_embeds + self.position_embeddings(self.position_ids) ## [Batch_size,Num_Patches,Embedding_size]  自学习的位置编码return patch_embeds

上面卷积配置表示我们希望卷积结果patch_size * embedding_size维度为了方便大家理解

这里简单介绍一下torch卷积

pytorch2D卷积

卷积通过卷积图像特征提取出来卷积操作可以如图所示

卷积操作本质输入区域展平向量同时卷积核展平向量做一次内积得到输出位置

torch卷积公式:

这里N_i 第i个batchCout_j 是指j输出通道输出星号代表二维区域weight权重input区域做一次卷积

这里可以看到多出一个通道概念其实对于图像来说输入通道就是RGB通道输出通道你希望一个卷积的图像区域多少特征

用图展示如下

这里彩色方块是1*1的卷积核我们希望一个三个输入通道输入卷积得到三个输出通道输出这样对于每个通道conv2D都会为其生成三个卷积核每个通道卷积结果卷积核顺序对应相加比如第一个输出通道的结果等于三个输入通道各自第一个卷积核卷积结果进行相加得到

由此再来这个公式

j输出通道结果等于所有输入通道j输出通道卷积卷积结果相加加上一个偏置矩阵得到

相关文章:

手搓多模态-03 顶层和嵌入层的搭建

声明:本代码非原创,是博主跟着国外大佬的视频教程编写的,本博客主要为记录学习成果所用。 我们首先开始编写视觉模型这一部分,这一部分的主要功能是接收一个batch的图像,并将其转化为上下文相关的嵌入向量,…...

【经验分享】将qt的ui文件转换为py文件

🌟 嗨,我是命运之光! 🌍 2024,每日百字,记录时光,感谢有你一路同行。 🚀 携手启航,探索未知,激发潜能,每一步都意义非凡。 首先简单的设计一个U…...

常用的国内镜像源

常见的 pip 镜像源 阿里云镜像:https://mirrors.aliyun.com/pypi/simple/ 清华大学镜像:https://pypi.tuna.tsinghua.edu.cn/simple 中国科学技术大学镜像:https://pypi.mirrors.ustc.edu.cn/simple/ 豆瓣镜像:https://pypi.doub…...

探秘JVM内部

在我们编写Java代码,点击运行后,会发生什么事呢? 首先,Java源代码会经过Java编译器将其编译成字节码,放在.class文件中 然后这些字节码文件就会被加载到jvm中,然后jvm会读取这些文件,调用相关…...

在HarmonyOS NEXT 开发中,如何指定一个号码,拉起系统拨号页面

大家好,我是 V 哥。 《鸿蒙 HarmonyOS 开发之路 卷1 ArkTS篇》已经出版上市了哈,有需要的朋友可以关注一下,卷2应用开发篇也马上要出版了,V 哥正在紧锣密鼓的写鸿蒙开发实战卷3的教材,卷3主要以项目实战为主&#xff0…...

利用空间-运动-回波稀疏性进行5D图像重建,以实现自由呼吸状态下肝脏定量磁共振成像(MRI)的加速采集|文献速递--深度学习医疗AI最新文献

Title 题目 5D image reconstruction exploiting space-motion-echo sparsity foraccelerated free-breathing quantitative liver MRI 利用空间-运动-回波稀疏性进行5D图像重建,以实现自由呼吸状态下肝脏定量磁共振成像(MRI)的加速采集 …...

Qt5 Mac系统检查休眠

在开发跨平台应用程序时,有时候我们需要检测系统的状态,比如是否处于休眠或唤醒状态。Qt是一个强大的跨平台应用开发框架,支持多种操作系统,包括Windows、Linux、macOS等。在这个场景下,我们关注的是如何在Qt5.10中检测到系统是否休眠以及在Mac上实现这一功能。本文将深入…...

ZKmall开源商城B2B2C电商用户隐私信息保护策略:数据脱敏全链路实践

随着业务的不断拓展和用户规模的持续扩大,用户隐私信息的保护也面临着前所未有的挑战。下面将深入探讨ZKmall开源商城在数据脱敏方面的实践,以及针对B2B2C电商用户隐私信息的具体保护策略。 数据脱敏,又称数据去标识化或数据匿名化&#xff0…...

Media streaming mental map

Media streaming is a huge topic with a bunch of scattered technologies, protocols, and formats. You may feel like hearing fragments without seeing the big picture. Let’s build that mental map together — here’s a high-level overview that connects everyt…...

linux Gitkraken 破解

ubuntu 安装 Gitkraken 9.x Pro 版本_gitcracken.git-CSDN博客...

SSL证书颁发机构有哪些呢

证书颁发机构(Certificate Authority, CA)是负责签发和管理数字证书的权威机构,分为公共信任的 CA 和私有/内部 CA。以下是常见的公共信任的 CA 分类及代表机构: 1. 国际知名公共 CA(浏览器/操作系统默认信任&#xff…...

13_pandas可视化_seaborn

导入库 import numpy as np import pandas as pd # import matplotlib.pyplot as plt #交互环境中不需要导入 import seaborn as sns sns.set_context({figure.figsize:[8, 6]}) # 设置图大小 # 屏蔽警告 import warnings warnings.filterwarnings("ignore")关系图 …...

Pgvector的安装

Pgvector的安装 向量化数据的存储,可以为 PostgreSQL 安装 vector 扩展来存储向量化数据 注意:在安装vector扩展之前,请先安装Postgres数据库 vector 扩展的步骤 1、下载vs_BuildTools 下载地址: https://visualstudio.microso…...

如何在大型项目中组织和管理 Vue 3 Hooks?

众所周知,Vue Hooks(通常指 Composition API 中的功能)是 Vue 3 引入的一种代码组织方式,用于更灵活地组合和复用逻辑。但是在项目中大量使用这种写法该如何更好的搭建结构呢?以下是可供参考实践的案例。 一、Hooks 组织原则 单一职责每个 Hook 应专注于完成单一功能,避…...

Django接入 免费的 AI 大模型——讯飞星火(2025年4月最新!!!)

上文有介绍deepseek接入,但是需要 付费,虽然 sliconflow 可以白嫖 token,但是毕竟是有限的,本文将介绍一款完全免费的 API——讯飞星火 目录 接入讯飞星火(免费) 测试对话 接入Django 扩展建议 接入讯飞星火…...

路由器学习

路由器原理 可以理解成把不同的网络打通,实现通信的设备。比如家里的路由器,他就是把家里的内网和互联网(外网)打通。 分类 1.(按应用场景分类) 路由器分为家用的,企业级的,运营…...

Redis 连接:深入解析与优化实践

Redis 连接:深入解析与优化实践 引言 Redis 作为一款高性能的键值型数据库,广泛应用于缓存、会话存储、消息队列等领域。Redis 的连接管理是确保其性能和稳定性的关键。本文将深入探讨 Redis 连接的原理、配置、优化方法以及常见问题,帮助您更好地掌握 Redis 连接技术。 …...

UE5学习记录part14

第17节 enemy behavior 173 making enemies move: AI Pawn Navigation 按P查看体积 So its very important that our nav mesh bounds volume encompasses all of the area that wed like our 因此,我们的导航网格边界体积必须包含我们希望 AI to navigate in and …...

【中间件】使用ElasticSearch提供的RestClientAPI操作ES

一、简介 ElasticSearch提供了RestClient来操作ES&#xff0c;包括对数据的增删改查&#xff0c;可参照官方文档&#xff1a;Java High Level REST Client 二、使用步骤&#xff1a; 可参照官方文档操作 导包 <dependency><groupId>org.elasticsearch.client<…...

Docker的备份与恢复

一、两种基本方式 docker export / import 在服务器上导出容器docker export container_name > container_backup.tar这里使用 > 重定向时默认保存路径为当前运行命令的路径&#xff0c;可以自行指定绝对路径来保存&#xff0c;后续加载时也使用对应的路径即可。 恢复为…...

C++ string 对象的操作(三十五)

1. string 对象的常见操作 下面的表格列出了 string 类型最常用的一些操作以及它们的功能&#xff1a; 操作说明示例os << s将字符串对象 s 写入输出流 os&#xff0c;返回 os。std::cout << s;is >> s从输入流 is 中读取字符串赋给 s&#xff08;以空白分…...

DAPP实战篇:规划下我们的开发线路

前言 在DApp实战篇&#xff1a;先用前端起个项目一文中我们起了一个前端项目&#xff0c;在后续开发中笔者将带领大家一步步完成这个DAPP&#xff0c;为了方便后续讲解&#xff0c;本篇将完整说明后续我们要进行的开发和思路。 主打前端 实际上一个完整的DAPP是由前端和智能…...

[leetcode] 面试经典 150 题——篇9:二叉树(番外:二叉树的遍历方式)

二叉树的遍历是指按照某种顺序访问二叉树中的每个节点。常见的遍历方式有四种&#xff1a;前序遍历&#xff08;Pre-order Traversal&#xff09;、中序遍历&#xff08;In-order Traversal&#xff09;、后序遍历&#xff08;Post-order Traversal&#xff09;以及层序遍历&am…...

【Elasticsearch】开启大数据分析的探索与预处理之旅

&#x1f9d1; 博主简介&#xff1a;CSDN博客专家&#xff0c;历代文学网&#xff08;PC端可以访问&#xff1a;https://literature.sinhy.com/#/literature?__c1000&#xff0c;移动端可微信小程序搜索“历代文学”&#xff09;总架构师&#xff0c;15年工作经验&#xff0c;…...

状态机思想编程练习

状态机实现LED流水灯 本次实验&#xff0c;我们将利用状态机的思想来进行Verilog编程实现一个LED流水灯&#xff0c;并通过Modelsim来进行模拟仿真&#xff0c;再到DE2-115开发板上进行验证。 ​ 首先进行主要代码的编写。 module led (input sys_clk,input sys_…...

C#:接口(interface)

目录 接口的核心是什么&#xff1f; 1. 什么是接口&#xff08;Interface&#xff09;&#xff0c;为什么要用它&#xff1f; 2. 如何定义和使用接口&#xff1f; 3.什么是引用接口&#xff1f; 如何“引用接口”&#xff1f; “引用接口”的关键点 4. 接口与抽象类的区…...

前端新增数据,但数据库里没有新增的数据

先看情况&#xff1a; 1.前端&#xff0c;可以进行删查改&#xff0c;但是新增数据之后&#xff0c;显示保存成功&#xff0c;也增加了空白的一行&#xff0c;但是数据没有显示出来。 2.后端接收到了数据&#xff0c;但返回结果的列表里面是空的&#xff1b;同时数据库里面没…...

Go语言的测试框架

Go语言测试框架详解 Go语言&#xff08;Golang&#xff09;自发布以来&#xff0c;因其简洁、高效和并发支持而受到广泛欢迎。在软件开发过程中&#xff0c;测试是确保代码质量与稳定性的重要环节。Go语言内置的测试框架为开发者提供了灵活而强大的测试工具&#xff0c;使得编…...

堆结构——面试算法题高频汇总

目录 引言 堆创建&增删改 堆构造过程 举个例子 堆插入元素 删除元素 在数组中找第k大的元素 举例 堆排序原理 合并k个排序链表 数据流中位数问题 引言 堆是将一组数据按照完全二叉树的存储顺序&#xff0c;将数据存储在一个一维数组中的结构。堆有两种结构&…...

httpx模块的使用

在使用requests模块发起请求时&#xff0c;报以下错误&#xff0c;表示服务器有可能使用的是http2.0协议版本&#xff0c;导致requests无法爬取。 此时就可以使用httpx模块爬取。 先下载httpx模块&#xff1a; pip install httpx[http2]然后用httpx发起请求&#xff1a; impo…...