用ResNet50+Qwen2-VL-2B-Instruct+LoRA模仿Diffusion-VLA的论文思路,在3090显卡上训练和测试成功
想一步步的实现Diffusion VLA论文的思路,不过论文的图像的输入用DINOv2进行特征提取的,我先把这个部分换成ResNet50。
老铁们,直接上代码:
from PIL import Image
import torch
import torchvision.models as models
from torch import nn
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from swanlab.integration.transformers import SwanLabCallback
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (TrainingArguments,Trainer,DataCollatorForSeq2Seq,Qwen2VLForConditionalGeneration,AutoProcessor,
)
import swanlab
import json
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.models as modelsclass CustomResNet(nn.Module):def __init__(self, output_size=(256, 1176)):super(CustomResNet, self).__init__()# 预训练的 ResNet 模型resnet = models.resnet50(pretrained=True)# 去掉 ResNet 的最后全连接层和池化层self.features = nn.Sequential(*list(resnet.children())[:-2]) # 去掉最后的FC层和AvgPool层# 自定义的卷积层,调整步幅和padding来控制尺寸self.conv1 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1) # 保持大小self.conv2 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1) # 保持大小self.conv3 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1) # 保持大小# 上采样层,用于增加特征图的尺寸self.upconv1 = nn.ConvTranspose2d(2048, 2048, kernel_size=4, stride=4, padding=0) # 上采样self.upconv2 = nn.ConvTranspose2d(2048, 2048, kernel_size=4, stride=4, padding=0) # 上采样# 最终卷积层将特征图变为单通道输出(灰度图)self.final_conv = nn.Conv2d(2048, 1, kernel_size=1) # 输出单通道def forward(self, x):# 获取ResNet的特征图x = self.features(x)# 经过卷积层x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)# 上采样阶段:增加特征图的尺寸x = self.upconv1(x) # 上采样1x = self.upconv2(x) # 上采样2# 使用插值进行微调输出尺寸x = F.interpolate(x, size=(256, 1176), mode='bilinear', align_corners=False)# 通过最后的卷积层输出(单通道)x = self.final_conv(x) # 通过最后的卷积层输出return xdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")# 创建模型并移动到设备上
model_ResNet = CustomResNet(output_size=(256, 1176)).to(device)# 定义图像预处理过程
image_transform = transforms.Compose([transforms.Resize((800, 800)), # 确保图像大小一致(通常为224x224)transforms.ToTensor(), # 转换为Tensor并标准化transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化
])def extract_resnet_features(image_path):"""使用ResNet提取图像特征"""image = Image.open(image_path).convert("RGB") # 加载图像并转换为RGBimage_tensor = image_transform(image).unsqueeze(0).to('cuda') # 添加batch维度并转换为cuda Tensor# features = resnet_extractor(image_tensor) # 从ResNet提取特征 features = model_ResNet(image_tensor)return featuresdef process_func(example):"""将数据集进行预处理,加入ResNet特征提取"""MAX_LENGTH = 8192input_ids, attention_mask, labels = [], [], []conversation = example["conversations"]input_content = conversation[0]["value"]output_content = conversation[1]["value"]file_path = input_content.split("<|vision_start|>")[1].split("<|vision_end|>")[0] # 获取图像路径messages = [{"role": "user","content": [{"type": "image","image": f"{file_path}","resized_height": 224, # 确保图像尺寸为224x224"resized_width": 224,},{"type": "text", "text": "COCO Yes:"},],}]text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) # 获取文本image_inputs, video_inputs = process_vision_info(messages) # 获取数据数据(预处理过)inputs = processor(text=[text],images=image_inputs,videos=video_inputs,padding=True,return_tensors="pt",)# print("inputs['pixel_values'] shape: ", inputs['pixel_values'].shape)# 提取图像特征image_tensor = extract_resnet_features(file_path) # 从图像路径提取特征# print("image_tensor shape: ", image_tensor.shape)inputs['pixel_values'] = image_tensor[0,0,:,:] # 替换图像特征为ResNet特征inputs = {key: value.tolist() for key, value in inputs.items()} # tensor -> list,为了方便拼接instruction = inputsresponse = tokenizer(f"{output_content}", add_special_tokens=False)input_ids = (instruction["input_ids"][0] + response["input_ids"] + [tokenizer.pad_token_id])attention_mask = instruction["attention_mask"][0] + response["attention_mask"] + [1]labels = ([-100] * len(instruction["input_ids"][0])+ response["input_ids"]+ [tokenizer.pad_token_id])if len(input_ids) > MAX_LENGTH: # 做一个截断input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]input_ids = torch.tensor(input_ids)attention_mask = torch.tensor(attention_mask)labels = torch.tensor(labels)inputs['pixel_values'] = torch.tensor(inputs['pixel_values'])inputs['image_grid_thw'] = torch.tensor(inputs['image_grid_thw']).squeeze(0) # 由(1,h,w)变换为(h,w)return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels,"pixel_values": inputs['pixel_values'], "image_grid_thw": inputs['image_grid_thw']}def predict(messages, model):# 准备推理text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)image_inputs, video_inputs = process_vision_info(messages)inputs = processor(text=[text],images=image_inputs,videos=video_inputs,padding=True,return_tensors="pt",)inputs = inputs.to("cuda")# 生成输出generated_ids = model.generate(**inputs, max_new_tokens=128)generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)return output_text[0]# 在modelscope上下载Qwen2-VL模型到本地目录下
model_dir = snapshot_download("Qwen/Qwen2-VL-2B-Instruct", cache_dir="./", revision="master")# 使用Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct/", use_fast=False, trust_remote_code=True)
processor = AutoProcessor.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct")# 加载模型
model = Qwen2VLForConditionalGeneration.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct/", device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True,)
model.enable_input_require_grads() # 开启梯度检查点时,要执行该方法
model.config.use_cache = False# 处理数据集:读取json文件
# 拆分成训练集和测试集,保存为data_vl_train.json和data_vl_test.json
train_json_path = "data_vl.json"
with open(train_json_path, 'r') as f:data = json.load(f)train_data = data[:-4]test_data = data[-4:]with open("data_vl_train.json", "w") as f:json.dump(train_data, f)with open("data_vl_test.json", "w") as f:json.dump(test_data, f)train_ds = Dataset.from_json("data_vl_train.json")
train_dataset = train_ds.map(process_func)# 配置LoRA
config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=False, # 训练模式r=4, #64, # Lora 秩lora_alpha= 1, #16, # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.05, # Dropout 比例bias="none",
)# 获取LoRA模型
peft_model = get_peft_model(model, config)# 配置训练参数
args = TrainingArguments(output_dir="./output/Qwen2-VL-2B",per_device_train_batch_size=1,gradient_accumulation_steps=1,logging_steps=10,logging_first_step=5,num_train_epochs=2,save_steps=100,learning_rate=1e-4,save_on_each_node=True,gradient_checkpointing=True,report_to="none",
)# 配置Trainer
trainer = Trainer(model=peft_model,args=args,train_dataset=train_dataset,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)# 开启模型训练
trainer.train()# ====================测试模式===================
# 配置测试参数
val_config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=True, # 训练模式r=4,#64, # Lora 秩lora_alpha=1,#16, # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.05, # Dropout 比例bias="none",
)# 获取测试模型
val_peft_model = PeftModel.from_pretrained(model, model_id="./output/Qwen2-VL-2B/checkpoint-992", config=val_config)# 读取测试数据
with open("data_vl_test.json", "r") as f:test_dataset = json.load(f)test_image_list = []
for item in test_dataset:input_image_prompt = item["conversations"][0]["value"]# 去掉前后的<|vision_start|>和<|vision_end|>origin_image_path = input_image_prompt.split("<|vision_start|>")[1].split("<|vision_end|>")[0]messages = [{"role": "user", "content": [{"type": "image", "image": origin_image_path},{"type": "text","text": "COCO Yes:"}]}]response = predict(messages, val_peft_model)messages.append({"role": "assistant", "content": f"{response}"})print(messages[-1])test_image_list.append(swanlab.Image(origin_image_path, caption=response))
我在3090显卡(24G显存)运行的结果:

相关文章:
用ResNet50+Qwen2-VL-2B-Instruct+LoRA模仿Diffusion-VLA的论文思路,在3090显卡上训练和测试成功
想一步步的实现Diffusion VLA论文的思路,不过论文的图像的输入用DINOv2进行特征提取的,我先把这个部分换成ResNet50。 老铁们,直接上代码: from PIL import Image import torch import torchvision.models as models from torch…...
创建.net core 8.0项目时,有个启用原生AOT发布是什么意思
启用原生 AOT 发布(Native AOT publishing) 是指在 .NET 6 及更高版本中使用 Ahead-of-Time (AOT) 编译 技术,将应用程序提前编译为本地机器代码,从而生成更高效、更快速启动的可执行文件。 1. AOT 编译是什么? AOT …...
2.1.7-1 io_uring的使用
一、背景 (1)下面几个有关异步操作的例子: a)客户端和服务端的异步关系,就是客户端发送请求后不需要等待结果,接下来发送其他请求。 b)对于服务端,客户端来请求后,服务…...
群论学习笔记
什么是对称? 对称是一个保持对象结构不变的变换,对称是一个过程,而不是一个具体的事物,伽罗瓦的对称是对方程根的置换,而一个置换就是对一系列事物的重排方式,严格的说,它也并不是这个重排本身…...
深入解析-正则表达式
学习正则,我们到底要学什么? 正则表达式(RegEx)是一种强大的文本匹配工具,广泛应用于数据验证、文本搜索、替换和解析等领域。学习正则表达式,我们不仅要掌握其语法规则,还需要学会如何高效地利…...
yolov5核查数据标注漏报和误报
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、误报二、漏报三、源码总结 前言 本文主要用于记录数据标注和模型预测之间的漏报和误报思想及其源码 提示:以下是本篇文章正文内容,…...
日志聚类算法 Drain 的实践与改良
在现实场景中,业务程序输出的日志往往规模庞大并且类型纷繁复杂。我们在查询和查看这些日志时,平铺的日志列表会让我们目不暇接,难以快速聚焦找到重要的日志条目。 在观测云中,我们在日志页面提供了聚类分析功能,可以…...
如何让用户在网页中填写PDF表格?
在网页中让用户直接填写PDF表格,可以大大简化填写、打印、扫描和提交表单的流程。通过使用复选框、按钮和列表等交互元素,PDF表格不仅让填写过程更高效,还能方便地在电脑或移动设备上访问和提交数据。 以下是在浏览器中显示可填写PDF表单的四…...
GXUOJ-算法-补题:22级《算法设计与分析》第一次课堂练习
2.最大子数组和 问题描述 代码解答 #include<bits/stdc.h> using namespace std; const int N1005; int sum,n,a[N]; int res-1;int result(){for(int i0;i<n;i){if(sum<0) suma[i];else{suma[i];resmax(res,sum);}}return res; } int main(){cin>>n;for(i…...
源代码编译安装X11及相关库、vim,配置vim(3)
一、vim插件安装 首先安装插件管理器Vundle ()。参照官网流程即可。vim的插件管理器有多个,只用Vundle就够了。然后~/.vimrc里写上要安装的插件: filetype offset rtp~/.vim/bundle/Vundle.vim call vundle#begin() Plugin VundleVim/Vundle.vim Plugin powerline…...
uniapp 微信小程序 自定义日历组件
效果图 功能:可以记录当天是否有某些任务或者某些记录 具体使用: 子组件代码 <template><view class"Accumulate"><view class"bx"><view class"bxx"><view class"plank"><…...
EdgeX规则引擎eKuiper
EdgeX 规则引擎eKuiper 一、架构设计 LF Edge eKuiper 是物联网数据分析和流式计算引擎。它是一个通用的边缘计算服务或中间件,为资源有限的边缘网关或设备而设计。 eKuiper 采用 Go 语言编写,其架构如下图所示: eKuiper 是 Golang 实现的轻量级物联网边缘分析、流式处理开源…...
react 优化方案
更详细的 React 优化方案可以分为性能优化、代码结构优化、开发效率提升等多个方面,结合实际项目需求,逐步应用这些优化策略。 一、性能优化 1. 避免不必要的重新渲染 React.memo: 缓存组件,防止组件在父组件重新渲染时无意义的重新渲染。 const ChildComponent = Reac…...
【Linux】sed编辑器
一、基本介绍 sed编辑器也叫流编辑器(stream editor),它是根据事先设计好得一组规则编辑数据流。 交互式文本编辑器(如Vim)中,可以用键盘命令交互式地插入、删除或替换文本数据。 sed编辑器是根据命令处理…...
(leetcode算法题)137. 只出现一次的数字 II
处理这种数据集中只有一个数出现的频次为1,其他数出现的频次均为k的题目 往往都是使用位运算的进行求解 假设 target在数据集中只出现了1次,其他数据n1, ... nj都出现了 k 次, 考虑数据集中所有数据的第 i 位的取值,那么将会有…...
在大数据环境下高效运用NoSQL与关系型数据库的结合策略
在大数据环境下,高效运用NoSQL与关系型数据库结合策略涉及到理解两者各自的优劣势,以及如何有效地整合它们。以下是一些代码示例和实际案例,以帮助你了解这种结合策略。 背景介绍 NoSQL数据库通常用于处理大量非结构化或半结构化的数据&…...
C语言——分支与循环语句
目录 一.分支语句 1.if语句 2.悬空else问题 3.switch语句 default子句 二.循环语句 1.while循环 whle循环流程图: break与continue 2.for循环 2.2for与while循环 2.3关于for循环的一道笔试题 3.do while 循环 三.猜数字游戏实现 四.goto语句 补充 …...
下载b站高清视频
需要使用的edge上的一个扩展插件,所以选择使用edge浏览器。 1、在edge浏览器上下载 强力视频下载合并 扩展插件 2、在edge上打开b站,登录自己账号(登录后才能下载到高清!!)。打开一个视频,选择自…...
常见 JVM垃圾回收器、内存分配策略、JVM调优
垃圾收集( Garbage Collection ,下文简称 GC),垃圾收集的历史远远比 Java久远。经过半个世纪的发展,今天的内存动态分配与内存回收技术已经相当成熟,一切看起来都进入了“自动化”时代,那为什么…...
【HarmonyOS应用开发——ArkTS语言】欢迎界面(启动加载页)的实现【合集】
目录 😋环境配置:华为HarmonyOS开发者 📺演示效果: 📖实验步骤及方法: 一、在media文件夹中添加想要使用的图片素材 二、在entry/src/main/ets/page目录下创建Welcome.ets文件 1. 整体结构与组件声…...
MPNet:旋转机械轻量化故障诊断模型详解python代码复现
目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...
深入剖析AI大模型:大模型时代的 Prompt 工程全解析
今天聊的内容,我认为是AI开发里面非常重要的内容。它在AI开发里无处不在,当你对 AI 助手说 "用李白的风格写一首关于人工智能的诗",或者让翻译模型 "将这段合同翻译成商务日语" 时,输入的这句话就是 Prompt。…...
【HarmonyOS 5.0】DevEco Testing:鸿蒙应用质量保障的终极武器
——全方位测试解决方案与代码实战 一、工具定位与核心能力 DevEco Testing是HarmonyOS官方推出的一体化测试平台,覆盖应用全生命周期测试需求,主要提供五大核心能力: 测试类型检测目标关键指标功能体验基…...
关于iview组件中使用 table , 绑定序号分页后序号从1开始的解决方案
问题描述:iview使用table 中type: "index",分页之后 ,索引还是从1开始,试过绑定后台返回数据的id, 这种方法可行,就是后台返回数据的每个页面id都不完全是按照从1开始的升序,因此百度了下,找到了…...
Frozen-Flask :将 Flask 应用“冻结”为静态文件
Frozen-Flask 是一个用于将 Flask 应用“冻结”为静态文件的 Python 扩展。它的核心用途是:将一个 Flask Web 应用生成成纯静态 HTML 文件,从而可以部署到静态网站托管服务上,如 GitHub Pages、Netlify 或任何支持静态文件的网站服务器。 &am…...
【C语言练习】080. 使用C语言实现简单的数据库操作
080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...
06 Deep learning神经网络编程基础 激活函数 --吴恩达
深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...
什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南
文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果