当前位置: 首页 > news >正文

用ResNet50+Qwen2-VL-2B-Instruct+LoRA模仿Diffusion-VLA的论文思路,在3090显卡上训练和测试成功

想一步步的实现Diffusion VLA论文的思路,不过论文的图像的输入用DINOv2进行特征提取的,我先把这个部分换成ResNet50。

老铁们,直接上代码:

from PIL import Image
import torch
import torchvision.models as models
from torch import nn
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from swanlab.integration.transformers import SwanLabCallback
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (TrainingArguments,Trainer,DataCollatorForSeq2Seq,Qwen2VLForConditionalGeneration,AutoProcessor,
)
import swanlab
import json
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.models as modelsclass CustomResNet(nn.Module):def __init__(self, output_size=(256, 1176)):super(CustomResNet, self).__init__()# 预训练的 ResNet 模型resnet = models.resnet50(pretrained=True)# 去掉 ResNet 的最后全连接层和池化层self.features = nn.Sequential(*list(resnet.children())[:-2])  # 去掉最后的FC层和AvgPool层# 自定义的卷积层,调整步幅和padding来控制尺寸self.conv1 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1)  # 保持大小self.conv2 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1)  # 保持大小self.conv3 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1)  # 保持大小# 上采样层,用于增加特征图的尺寸self.upconv1 = nn.ConvTranspose2d(2048, 2048, kernel_size=4, stride=4, padding=0)  # 上采样self.upconv2 = nn.ConvTranspose2d(2048, 2048, kernel_size=4, stride=4, padding=0)  # 上采样# 最终卷积层将特征图变为单通道输出(灰度图)self.final_conv = nn.Conv2d(2048, 1, kernel_size=1)  # 输出单通道def forward(self, x):# 获取ResNet的特征图x = self.features(x)# 经过卷积层x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)# 上采样阶段:增加特征图的尺寸x = self.upconv1(x)  # 上采样1x = self.upconv2(x)  # 上采样2# 使用插值进行微调输出尺寸x = F.interpolate(x, size=(256, 1176), mode='bilinear', align_corners=False)# 通过最后的卷积层输出(单通道)x = self.final_conv(x)  # 通过最后的卷积层输出return xdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")# 创建模型并移动到设备上
model_ResNet = CustomResNet(output_size=(256, 1176)).to(device)# 定义图像预处理过程
image_transform = transforms.Compose([transforms.Resize((800, 800)),  # 确保图像大小一致(通常为224x224)transforms.ToTensor(),  # 转换为Tensor并标准化transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])def extract_resnet_features(image_path):"""使用ResNet提取图像特征"""image = Image.open(image_path).convert("RGB")  # 加载图像并转换为RGBimage_tensor = image_transform(image).unsqueeze(0).to('cuda')  # 添加batch维度并转换为cuda Tensor# features = resnet_extractor(image_tensor)  # 从ResNet提取特征    features = model_ResNet(image_tensor)return featuresdef process_func(example):"""将数据集进行预处理,加入ResNet特征提取"""MAX_LENGTH = 8192input_ids, attention_mask, labels = [], [], []conversation = example["conversations"]input_content = conversation[0]["value"]output_content = conversation[1]["value"]file_path = input_content.split("<|vision_start|>")[1].split("<|vision_end|>")[0]  # 获取图像路径messages = [{"role": "user","content": [{"type": "image","image": f"{file_path}","resized_height": 224,  # 确保图像尺寸为224x224"resized_width": 224,},{"type": "text", "text": "COCO Yes:"},],}]text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)  # 获取文本image_inputs, video_inputs = process_vision_info(messages)  # 获取数据数据(预处理过)inputs = processor(text=[text],images=image_inputs,videos=video_inputs,padding=True,return_tensors="pt",)# print("inputs['pixel_values'] shape: ", inputs['pixel_values'].shape)# 提取图像特征image_tensor = extract_resnet_features(file_path)  # 从图像路径提取特征# print("image_tensor shape: ", image_tensor.shape)inputs['pixel_values'] = image_tensor[0,0,:,:]  # 替换图像特征为ResNet特征inputs = {key: value.tolist() for key, value in inputs.items()}  # tensor -> list,为了方便拼接instruction = inputsresponse = tokenizer(f"{output_content}", add_special_tokens=False)input_ids = (instruction["input_ids"][0] + response["input_ids"] + [tokenizer.pad_token_id])attention_mask = instruction["attention_mask"][0] + response["attention_mask"] + [1]labels = ([-100] * len(instruction["input_ids"][0])+ response["input_ids"]+ [tokenizer.pad_token_id])if len(input_ids) > MAX_LENGTH:  # 做一个截断input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]input_ids = torch.tensor(input_ids)attention_mask = torch.tensor(attention_mask)labels = torch.tensor(labels)inputs['pixel_values'] = torch.tensor(inputs['pixel_values'])inputs['image_grid_thw'] = torch.tensor(inputs['image_grid_thw']).squeeze(0)  # 由(1,h,w)变换为(h,w)return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels,"pixel_values": inputs['pixel_values'], "image_grid_thw": inputs['image_grid_thw']}def predict(messages, model):# 准备推理text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)image_inputs, video_inputs = process_vision_info(messages)inputs = processor(text=[text],images=image_inputs,videos=video_inputs,padding=True,return_tensors="pt",)inputs = inputs.to("cuda")# 生成输出generated_ids = model.generate(**inputs, max_new_tokens=128)generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)return output_text[0]# 在modelscope上下载Qwen2-VL模型到本地目录下
model_dir = snapshot_download("Qwen/Qwen2-VL-2B-Instruct", cache_dir="./", revision="master")# 使用Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct/", use_fast=False, trust_remote_code=True)
processor = AutoProcessor.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct")# 加载模型
model = Qwen2VLForConditionalGeneration.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct/", device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True,)
model.enable_input_require_grads()  # 开启梯度检查点时,要执行该方法
model.config.use_cache = False# 处理数据集:读取json文件
# 拆分成训练集和测试集,保存为data_vl_train.json和data_vl_test.json
train_json_path = "data_vl.json"
with open(train_json_path, 'r') as f:data = json.load(f)train_data = data[:-4]test_data = data[-4:]with open("data_vl_train.json", "w") as f:json.dump(train_data, f)with open("data_vl_test.json", "w") as f:json.dump(test_data, f)train_ds = Dataset.from_json("data_vl_train.json")
train_dataset = train_ds.map(process_func)# 配置LoRA
config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=False,  # 训练模式r=4, #64,  # Lora 秩lora_alpha= 1, #16,  # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.05,  # Dropout 比例bias="none",
)# 获取LoRA模型
peft_model = get_peft_model(model, config)# 配置训练参数
args = TrainingArguments(output_dir="./output/Qwen2-VL-2B",per_device_train_batch_size=1,gradient_accumulation_steps=1,logging_steps=10,logging_first_step=5,num_train_epochs=2,save_steps=100,learning_rate=1e-4,save_on_each_node=True,gradient_checkpointing=True,report_to="none",
)# 配置Trainer
trainer = Trainer(model=peft_model,args=args,train_dataset=train_dataset,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)# 开启模型训练
trainer.train()# ====================测试模式===================
# 配置测试参数
val_config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=True,  # 训练模式r=4,#64,  # Lora 秩lora_alpha=1,#16,  # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.05,  # Dropout 比例bias="none",
)# 获取测试模型
val_peft_model = PeftModel.from_pretrained(model, model_id="./output/Qwen2-VL-2B/checkpoint-992", config=val_config)# 读取测试数据
with open("data_vl_test.json", "r") as f:test_dataset = json.load(f)test_image_list = []
for item in test_dataset:input_image_prompt = item["conversations"][0]["value"]# 去掉前后的<|vision_start|>和<|vision_end|>origin_image_path = input_image_prompt.split("<|vision_start|>")[1].split("<|vision_end|>")[0]messages = [{"role": "user", "content": [{"type": "image", "image": origin_image_path},{"type": "text","text": "COCO Yes:"}]}]response = predict(messages, val_peft_model)messages.append({"role": "assistant", "content": f"{response}"})print(messages[-1])test_image_list.append(swanlab.Image(origin_image_path, caption=response))

我在3090显卡(24G显存)运行的结果:

相关文章:

用ResNet50+Qwen2-VL-2B-Instruct+LoRA模仿Diffusion-VLA的论文思路,在3090显卡上训练和测试成功

想一步步的实现Diffusion VLA论文的思路&#xff0c;不过论文的图像的输入用DINOv2进行特征提取的&#xff0c;我先把这个部分换成ResNet50。 老铁们&#xff0c;直接上代码&#xff1a; from PIL import Image import torch import torchvision.models as models from torch…...

创建.net core 8.0项目时,有个启用原生AOT发布是什么意思

启用原生 AOT 发布&#xff08;Native AOT publishing&#xff09; 是指在 .NET 6 及更高版本中使用 Ahead-of-Time (AOT) 编译 技术&#xff0c;将应用程序提前编译为本地机器代码&#xff0c;从而生成更高效、更快速启动的可执行文件。 1. AOT 编译是什么&#xff1f; AOT …...

2.1.7-1 io_uring的使用

一、背景 &#xff08;1&#xff09;下面几个有关异步操作的例子&#xff1a; a&#xff09;客户端和服务端的异步关系&#xff0c;就是客户端发送请求后不需要等待结果&#xff0c;接下来发送其他请求。 b&#xff09;对于服务端&#xff0c;客户端来请求后&#xff0c;服务…...

群论学习笔记

什么是对称&#xff1f; 对称是一个保持对象结构不变的变换&#xff0c;对称是一个过程&#xff0c;而不是一个具体的事物&#xff0c;伽罗瓦的对称是对方程根的置换&#xff0c;而一个置换就是对一系列事物的重排方式&#xff0c;严格的说&#xff0c;它也并不是这个重排本身…...

深入解析-正则表达式

学习正则&#xff0c;我们到底要学什么&#xff1f; 正则表达式&#xff08;RegEx&#xff09;是一种强大的文本匹配工具&#xff0c;广泛应用于数据验证、文本搜索、替换和解析等领域。学习正则表达式&#xff0c;我们不仅要掌握其语法规则&#xff0c;还需要学会如何高效地利…...

yolov5核查数据标注漏报和误报

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、误报二、漏报三、源码总结 前言 本文主要用于记录数据标注和模型预测之间的漏报和误报思想及其源码 提示&#xff1a;以下是本篇文章正文内容&#xff0c;…...

日志聚类算法 Drain 的实践与改良

在现实场景中&#xff0c;业务程序输出的日志往往规模庞大并且类型纷繁复杂。我们在查询和查看这些日志时&#xff0c;平铺的日志列表会让我们目不暇接&#xff0c;难以快速聚焦找到重要的日志条目。 在观测云中&#xff0c;我们在日志页面提供了聚类分析功能&#xff0c;可以…...

如何让用户在网页中填写PDF表格?

在网页中让用户直接填写PDF表格&#xff0c;可以大大简化填写、打印、扫描和提交表单的流程。通过使用复选框、按钮和列表等交互元素&#xff0c;PDF表格不仅让填写过程更高效&#xff0c;还能方便地在电脑或移动设备上访问和提交数据。 以下是在浏览器中显示可填写PDF表单的四…...

GXUOJ-算法-补题:22级《算法设计与分析》第一次课堂练习

2.最大子数组和 问题描述 代码解答 #include<bits/stdc.h> using namespace std; const int N1005; int sum,n,a[N]; int res-1;int result(){for(int i0;i<n;i){if(sum<0) suma[i];else{suma[i];resmax(res,sum);}}return res; } int main(){cin>>n;for(i…...

源代码编译安装X11及相关库、vim,配置vim(3)

一、vim插件安装 首先安装插件管理器Vundle ()。参照官网流程即可。vim的插件管理器有多个&#xff0c;只用Vundle就够了。然后~/.vimrc里写上要安装的插件: filetype offset rtp~/.vim/bundle/Vundle.vim call vundle#begin() Plugin VundleVim/Vundle.vim Plugin powerline…...

uniapp 微信小程序 自定义日历组件

效果图 功能&#xff1a;可以记录当天是否有某些任务或者某些记录 具体使用&#xff1a; 子组件代码 <template><view class"Accumulate"><view class"bx"><view class"bxx"><view class"plank"><…...

EdgeX规则引擎eKuiper

EdgeX 规则引擎eKuiper 一、架构设计 LF Edge eKuiper 是物联网数据分析和流式计算引擎。它是一个通用的边缘计算服务或中间件,为资源有限的边缘网关或设备而设计。 eKuiper 采用 Go 语言编写,其架构如下图所示: eKuiper 是 Golang 实现的轻量级物联网边缘分析、流式处理开源…...

react 优化方案

更详细的 React 优化方案可以分为性能优化、代码结构优化、开发效率提升等多个方面,结合实际项目需求,逐步应用这些优化策略。 一、性能优化 1. 避免不必要的重新渲染 React.memo: 缓存组件,防止组件在父组件重新渲染时无意义的重新渲染。 const ChildComponent = Reac…...

【Linux】sed编辑器

一、基本介绍 sed编辑器也叫流编辑器&#xff08;stream editor&#xff09;&#xff0c;它是根据事先设计好得一组规则编辑数据流。 交互式文本编辑器&#xff08;如Vim&#xff09;中&#xff0c;可以用键盘命令交互式地插入、删除或替换文本数据。 sed编辑器是根据命令处理…...

(leetcode算法题)137. 只出现一次的数字 II

处理这种数据集中只有一个数出现的频次为1&#xff0c;其他数出现的频次均为k的题目 往往都是使用位运算的进行求解 假设 target在数据集中只出现了1次&#xff0c;其他数据n1, ... nj都出现了 k 次&#xff0c; 考虑数据集中所有数据的第 i 位的取值&#xff0c;那么将会有…...

在大数据环境下高效运用NoSQL与关系型数据库的结合策略

在大数据环境下&#xff0c;高效运用NoSQL与关系型数据库结合策略涉及到理解两者各自的优劣势&#xff0c;以及如何有效地整合它们。以下是一些代码示例和实际案例&#xff0c;以帮助你了解这种结合策略。 背景介绍 NoSQL数据库通常用于处理大量非结构化或半结构化的数据&…...

C语言——分支与循环语句

目录 一.分支语句 1.if语句 2.悬空else问题 3.switch语句 default子句 二.循环语句 1.while循环 whle循环流程图&#xff1a; break与continue 2.for循环 2.2for与while循环 2.3关于for循环的一道笔试题 3.do while 循环 三.猜数字游戏实现 四.goto语句 补充 …...

下载b站高清视频

需要使用的edge上的一个扩展插件&#xff0c;所以选择使用edge浏览器。 1、在edge浏览器上下载 强力视频下载合并 扩展插件 2、在edge上打开b站&#xff0c;登录自己账号&#xff08;登录后才能下载到高清&#xff01;&#xff01;&#xff09;。打开一个视频&#xff0c;选择自…...

常见 JVM垃圾回收器、内存分配策略、JVM调优

垃圾收集&#xff08; Garbage Collection &#xff0c;下文简称 GC&#xff09;&#xff0c;垃圾收集的历史远远比 Java久远。经过半个世纪的发展&#xff0c;今天的内存动态分配与内存回收技术已经相当成熟&#xff0c;一切看起来都进入了“自动化”时代&#xff0c;那为什么…...

【HarmonyOS应用开发——ArkTS语言】欢迎界面(启动加载页)的实现【合集】

目录 &#x1f60b;环境配置&#xff1a;华为HarmonyOS开发者 &#x1f4fa;演示效果&#xff1a; &#x1f4d6;实验步骤及方法&#xff1a; 一、在media文件夹中添加想要使用的图片素材​ 二、在entry/src/main/ets/page目录下创建Welcome.ets文件 1. 整体结构与组件声…...

【杂谈】-递归进化:人工智能的自我改进与监管挑战

递归进化&#xff1a;人工智能的自我改进与监管挑战 文章目录 递归进化&#xff1a;人工智能的自我改进与监管挑战1、自我改进型人工智能的崛起2、人工智能如何挑战人类监管&#xff1f;3、确保人工智能受控的策略4、人类在人工智能发展中的角色5、平衡自主性与控制力6、总结与…...

【kafka】Golang实现分布式Masscan任务调度系统

要求&#xff1a; 输出两个程序&#xff0c;一个命令行程序&#xff08;命令行参数用flag&#xff09;和一个服务端程序。 命令行程序支持通过命令行参数配置下发IP或IP段、端口、扫描带宽&#xff0c;然后将消息推送到kafka里面。 服务端程序&#xff1a; 从kafka消费者接收…...

【算法训练营Day07】字符串part1

文章目录 反转字符串反转字符串II替换数字 反转字符串 题目链接&#xff1a;344. 反转字符串 双指针法&#xff0c;两个指针的元素直接调转即可 class Solution {public void reverseString(char[] s) {int head 0;int end s.length - 1;while(head < end) {char temp …...

EtherNet/IP转DeviceNet协议网关详解

一&#xff0c;设备主要功能 疆鸿智能JH-DVN-EIP本产品是自主研发的一款EtherNet/IP从站功能的通讯网关。该产品主要功能是连接DeviceNet总线和EtherNet/IP网络&#xff0c;本网关连接到EtherNet/IP总线中做为从站使用&#xff0c;连接到DeviceNet总线中做为从站使用。 在自动…...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角&#xff0c;以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向&#xff0c;距离坐标原点x个像素;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐标原点y个像素。 坐标体系-像素 …...

OPENCV形态学基础之二腐蚀

一.腐蚀的原理 (图1) 数学表达式&#xff1a;dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一&#xff0c;腐蚀跟膨胀属于反向操作&#xff0c;膨胀是把图像图像变大&#xff0c;而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...

Android第十三次面试总结(四大 组件基础)

Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成&#xff0c;用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机&#xff1a; ​onCreate()​​ ​调用时机​&#xff1a;Activity 首次创建时调用。​…...

华为OD机考-机房布局

import java.util.*;public class DemoTest5 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseSystem.out.println(solve(in.nextLine()));}}priv…...

mac 安装homebrew (nvm 及git)

mac 安装nvm 及git 万恶之源 mac 安装这些东西离不开Xcode。及homebrew 一、先说安装git步骤 通用&#xff1a; 方法一&#xff1a;使用 Homebrew 安装 Git&#xff08;推荐&#xff09; 步骤如下&#xff1a;打开终端&#xff08;Terminal.app&#xff09; 1.安装 Homebrew…...

苹果AI眼镜:从“工具”到“社交姿态”的范式革命——重新定义AI交互入口的未来机会

在2025年的AI硬件浪潮中,苹果AI眼镜(Apple Glasses)正在引发一场关于“人机交互形态”的深度思考。它并非简单地替代AirPods或Apple Watch,而是开辟了一个全新的、日常可接受的AI入口。其核心价值不在于功能的堆叠,而在于如何通过形态设计打破社交壁垒,成为用户“全天佩戴…...