当前位置: 首页 > news >正文

Qwen2-VL微调体验

1.配置环境

2.数据集准备

3.模型下载

4.注册SwanLab

5.微调

6.训练过程可视化


1.配置环境

本博客使用的是2B模型,所以仅用了单卡3090,若大一点的模型,自行根据实际情况准备显卡

安装Python>=3.8

安装Qwen2-VL必要的库

pip install modelscope==1.18.0

pip install transformers==4.46.2

pip install accelerate==1.1.1

pip install datasets==2.18.0

pip install peft==0.13.2

pip install qwen-vl-utils==0.0.8

2.数据集准备

本博客使用的是coco_2014_caption数据集中的部分,json格式如下:

数据集下载可以使用如下同样会产生一个csv:


from modelscope.msdatasets import MsDataset
import os
import pandas as pdMAX_DATA_NUMBER = 500if not os.path.exists('coco_2014_caption'):# 从modelscope下载COCO 2014图像描述数据集ds =  MsDataset.load('modelscope/coco_2014_caption', subset_name='coco_2014_caption', split='train')print(len(ds))total = min(MAX_DATA_NUMBER, len(ds))os.makedirs('coco_2014_caption', exist_ok=True)image_paths = []captions = []for i in range(total):# 获取每个样本的信息item = ds[i]image_id = item['image_id']caption = item['caption']image = item['image']# 保存图片并记录路径image_path = os.path.abspath(f'coco_2014_caption/{image_id}.jpg')image.save(image_path)# 将路径和描述添加到列表中image_paths.append(image_path)captions.append(caption)# 每处理50张图片打印一次进度if (i + 1) % 50 == 0:print(f'Processing {i+1}/{total} images ({(i+1)/total*100:.1f}%)')# 将图片路径和描述保存为CSV文件df = pd.DataFrame({'image_path': image_paths,'caption': captions})# 将数据保存为CSV文件df.to_csv('./coco-2024-dataset.csv', index=False)print(f'数据处理完成,共处理了{total}张图片')else:print('coco_2014_caption目录已存在,跳过数据处理步骤')

我们需要得到json,所以执行下面脚本得到:

import pandas as pd
import json# 载入CSV文件
df = pd.read_csv('./coco-2024-dataset.csv')
conversations = []# 添加对话数据
for i in range(len(df)):conversations.append({"id": f"identity_{i+1}","conversations": [{"from": "user","value": f"COCO Yes: <|vision_start|>{df.iloc[i]['image_path']}<|vision_end|>"},{"from": "assistant", "value": df.iloc[i]['caption']}]})# 保存为Json
with open('data_vl.json', 'w', encoding='utf-8') as f:json.dump(conversations, f, ensure_ascii=False, indent=2)

以上则是数据集准备,若使用自定义数据集则根据上面的格式准备!!!

3.模型下载

本博客使用魔塔社区中的Qwen2-VL-2B模型

from modelscope import snapshot_download, AutoTokenizer
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq, Qwen2VLForConditionalGeneration, AutoProcessor
import torchmodel_dir = snapshot_download("Qwen/Qwen2-VL-2B-Instruct", cache_dir="./", revision="master")
tokenizer = AutoTokenizer.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct/", use_fast=False, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct/", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True,)
model.enable_input_require_grads()  

4.注册SwanLab

为了微调时随时查看各项指标,注册SwanLab,复制key,粘贴到运行过程中,如下图:

5.微调

微调前保证当前文件夹下面包含以下:

train.py代码如下:

import torch
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from swanlab.integration.transformers import SwanLabCallback
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (TrainingArguments,Trainer,DataCollatorForSeq2Seq,Qwen2VLForConditionalGeneration,AutoProcessor,
)
import swanlab
import jsondef process_func(example):"""将数据集进行预处理"""MAX_LENGTH = 8192input_ids, attention_mask, labels = [], [], []conversation = example["conversations"]input_content = conversation[0]["value"]output_content = conversation[1]["value"]file_path = input_content.split("<|vision_start|>")[1].split("<|vision_end|>")[0]  # 获取图像路径messages = [{"role": "user","content": [{"type": "image","image": f"{file_path}","resized_height": 280,"resized_width": 280,},{"type": "text", "text": "COCO Yes:"},],}]text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)  # 获取文本image_inputs, video_inputs = process_vision_info(messages)  # 获取数据数据(预处理过)inputs = processor(text=[text],images=image_inputs,videos=video_inputs,padding=True,return_tensors="pt",)inputs = {key: value.tolist() for key, value in inputs.items()} #tensor -> list,为了方便拼接instruction = inputsresponse = tokenizer(f"{output_content}", add_special_tokens=False)input_ids = (instruction["input_ids"][0] + response["input_ids"] + [tokenizer.pad_token_id])attention_mask = instruction["attention_mask"][0] + response["attention_mask"] + [1]labels = ([-100] * len(instruction["input_ids"][0])+ response["input_ids"]+ [tokenizer.pad_token_id])if len(input_ids) > MAX_LENGTH:  # 做一个截断input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]input_ids = torch.tensor(input_ids)attention_mask = torch.tensor(attention_mask)labels = torch.tensor(labels)inputs['pixel_values'] = torch.tensor(inputs['pixel_values'])inputs['image_grid_thw'] = torch.tensor(inputs['image_grid_thw']).squeeze(0)  #由(1,h,w)变换为(h,w)return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels,"pixel_values": inputs['pixel_values'], "image_grid_thw": inputs['image_grid_thw']}def predict(messages, model):# 准备推理text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)image_inputs, video_inputs = process_vision_info(messages)inputs = processor(text=[text],images=image_inputs,videos=video_inputs,padding=True,return_tensors="pt",)inputs = inputs.to("cuda")# 生成输出generated_ids = model.generate(**inputs, max_new_tokens=128)generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)return output_text[0]# 在modelscope上下载Qwen2-VL模型到本地目录下
model_dir = snapshot_download("Qwen/Qwen2-VL-2B-Instruct", cache_dir="./", revision="master")# 使用Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct/", use_fast=False, trust_remote_code=True)
processor = AutoProcessor.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct")model = Qwen2VLForConditionalGeneration.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct/", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True,)
model.enable_input_require_grads()  # 开启梯度检查点时,要执行该方法# 处理数据集:读取json文件
# 拆分成训练集和测试集,保存为data_vl_train.json和data_vl_test.json
train_json_path = "data_vl.json"
with open(train_json_path, 'r') as f:data = json.load(f)train_data = data[:-4]test_data = data[-4:]with open("data_vl_train.json", "w") as f:json.dump(train_data, f)with open("data_vl_test.json", "w") as f:json.dump(test_data, f)train_ds = Dataset.from_json("data_vl_train.json")
train_dataset = train_ds.map(process_func)# 配置LoRA
config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=False,  # 训练模式r=64,  # Lora 秩lora_alpha=16,  # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.05,  # Dropout 比例bias="none",
)# 获取LoRA模型
peft_model = get_peft_model(model, config)# 配置训练参数
args = TrainingArguments(output_dir="./output/Qwen2-VL-2B",per_device_train_batch_size=4,gradient_accumulation_steps=4,logging_steps=10,logging_first_step=5,num_train_epochs=2,save_steps=100,learning_rate=1e-4,save_on_each_node=True,gradient_checkpointing=True,report_to="none",
)# 设置SwanLab回调
swanlab_callback = SwanLabCallback(project="Qwen2-VL-finetune",experiment_name="qwen2-vl-coco2014",config={"model": "https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct","dataset": "https://modelscope.cn/datasets/modelscope/coco_2014_caption/quickstart","github": "https://github.com/datawhalechina/self-llm","prompt": "COCO Yes: ","train_data_number": len(train_data),"lora_rank": 64,"lora_alpha": 16,"lora_dropout": 0.1,},
)# 配置Trainer
trainer = Trainer(model=peft_model,args=args,train_dataset=train_dataset,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),callbacks=[swanlab_callback],
)# 开启模型训练
trainer.train()# 配置测试参数
val_config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=True,  # 训练模式r=64,  # Lora 秩lora_alpha=16,  # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.05,  # Dropout 比例bias="none",
)# 获取测试模型
val_peft_model = PeftModel.from_pretrained(model, model_id="./output/Qwen2-VL-2B/checkpoint-62", config=val_config)# 读取测试数据
with open("data_vl_test.json", "r") as f:test_dataset = json.load(f)test_image_list = []
for item in test_dataset:input_image_prompt = item["conversations"][0]["value"]# 去掉前后的<|vision_start|>和<|vision_end|>origin_image_path = input_image_prompt.split("<|vision_start|>")[1].split("<|vision_end|>")[0]messages = [{"role": "user", "content": [{"type": "image", "image": origin_image_path},{"type": "text","text": "COCO Yes:"}]}]response = predict(messages, val_peft_model)messages.append({"role": "assistant", "content": f"{response}"})print(messages[-1])test_image_list.append(swanlab.Image(origin_image_path, caption=response))swanlab.log({"Prediction": test_image_list})swanlab.finish()

6.训练过程可视化

相关文章:

Qwen2-VL微调体验

1.配置环境 2.数据集准备 3.模型下载 4.注册SwanLab 5.微调 6.训练过程可视化 1.配置环境 本博客使用的是2B模型&#xff0c;所以仅用了单卡3090&#xff0c;若大一点的模型&#xff0c;自行根据实际情况准备显卡 安装Python>3.8 安装Qwen2-VL必要的库 pip install…...

论文的模拟环境和实验环境

模拟环境和实验环境 在撰写SCI计算机领域论文时,模拟环境和实验环境是两个重要的概念,它们之间存在显著的差异。 模拟环境主要是利用计算机、数学方法等手段对实际系统进行描述和分析的过程。在计算机科学中,模拟环境可以用于模拟各种算法、系统或网络的行为,以便在不需要…...

MySQL EXPLAIN 详解:一眼看懂查询计划

在日常的数据库开发中&#xff0c;我们经常需要分析 SQL 查询性能&#xff0c;而 EXPLAIN 是 MySQL 提供的利器&#xff0c;可以帮我们快速理解查询计划&#xff0c;优化慢查询。本文将详细解析 EXPLAIN 的输出字段及其含义&#xff0c;并结合实际案例分享优化思路。 一、什么是…...

自动呼入机器人如何与人工客服进行无缝切换?

自动呼入机器人如何与人工客服进行无缝切换&#xff1f; 原作者&#xff1a;开源呼叫中心FreeIPCC&#xff0c;其Github&#xff1a;https://github.com/lihaiya/freeipcc 自动呼入机器人与人工客服的无缝切换详解 自动呼入机器人与人工客服之间的无缝切换是确保客户体验连续…...

二分类模型的性能评价指标

1. 混淆矩阵 (Confusion Matrix) 预测正类预测负类实际正类 (P)True Positive (TP)False Negative (FN)实际负类 (N)False Positive (FP)True Negative (TN) True Positive (TP): 模型正确预测为正类的样本数。True Negative (TN): 模型正确预测为负类的样本数。False Positi…...

鸿蒙操作系统简介

华为鸿蒙系统&#xff08;HUAWEI HarmonyOS&#xff09;&#xff0c;是华为公司于2019年8月9日在东莞举行的华为开发者大会&#xff08;HDC.2019&#xff09;上正式发布的面向全场景的分布式操作系统&#xff0c;可以创造一个超级虚拟终端互联的世界&#xff0c;将人、设备、场…...

单片机:实现蜂鸣器数码管的显示(附带源码)

单片机实现蜂鸣器数码管显示 蜂鸣器和数码管在嵌入式系统中广泛应用。蜂鸣器可以发出声音警告或提示&#xff0c;而数码管则用于显示数字或字母。在本项目中&#xff0c;我们将通过8051单片机实现一个控制蜂鸣器和数码管显示的系统&#xff0c;结合使用蜂鸣器和数码管&#xf…...

C语言期末复习笔记(上)

目录 一、为什么要学习C语言 1.C语言适合做什么 2.开发C程序的步骤 3.常用术语 二、C语言数据结构 1.常量与变量 &#xff08;1&#xff09;常量 ​编辑 &#xff08;2&#xff09;变量 2.数据类型 ​编辑 &#xff08;1&#xff09;数据类型的分类 &#xff08;2&a…...

HarmonyOS 实时监听与获取 Wi-Fi 信息

文章目录 摘要项目功能概述代码模块详细说明创建 Wi-Fi 状态保存对象Wi-Fi 状态监听模块获取当前 Wi-Fi 信息整合主模块 运行效果展示性能分析总结 摘要 本文展示了如何使用 HarmonyOS 框架开发一个 Demo&#xff0c;用于监听手机的 Wi-Fi 状态变化并实时获取连接的 Wi-Fi 信息…...

Unity超优质动态天气插件(含一年四季各种天气变化,可用于单机局域网VR)

效果展示&#xff1a;https://www.bilibili.com/video/BV1CkkcYHENf/?spm_id_from333.1387.homepage.video_card.click 在你的项目中设置enviro真的很容易&#xff01;导入包裹并按照以下步骤操作开始的步骤&#xff01; 1. 拖拽“EnviroSky”预制件&#xff08;“environme…...

1 JVM JDK JRE之间的区别以及使用字节码的好处

JDK jdk是编译java源文件成class文件的&#xff0c;我们使用javac命令把java源文件编译成class文件。 我们在java安装的目录下找到bin文件夹&#xff0c;如下图所示: 遵循着编译原理&#xff0c;把java源文件编译成JVM可识别的机器码。 其中还包括jar打包工具等。主要是针对…...

【网络安全】网站常见安全漏洞—服务端漏洞介绍

文章目录 网站常见安全漏洞—服务端漏洞介绍引言1. 第三方组件漏洞什么是第三方组件漏洞&#xff1f;如何防范&#xff1f; 2. SQL 注入什么是SQL注入&#xff1f;如何防范&#xff1f; 3. 命令执行漏洞什么是命令执行漏洞&#xff1f;如何防范&#xff1f; 4. 越权漏洞什么是越…...

MAPTR:在线矢量化高精地图构建的结构化建模与学习(2208)

MAPTR: STRUCTURED MODELING AND LEARNING FOR ONLINE VECTORIZED HD MAP CONSTRUCTION MAPTR&#xff1a;在线矢量化高精地图构建的结构化建模与学习 ABSTRACT High-definition (HD) map provides abundant and precise environmental information of the driving scene, se…...

基于容器的云原生,让业务更自由地翱翔云端

无论是要构建一个应用或开发一个更庞大的解决方案&#xff0c;在技术选型时&#xff0c;技术的开放性和可移植性已经成为很多企业优先考虑的问题之一。毕竟没人希望自己未来的发展方向和成长速度被自己若干年前选择使用的某项技术所限制或拖累。 那么当你的业务已经上云&#x…...

大屏开源项目go-view二次开发2----半环形控件(C#)

环境搭建参考&#xff1a; 大屏开源项目go-view二次开发1----环境搭建(C#)-CSDN博客 要做的半环形控件最终效果如下图&#xff1a; 步骤如下&#xff1a; 1 在go-view前端项目的\src\packages\components\Charts目录下新增Others目录&#xff0c;并在Others目录下新增PieExt…...

web:pc端企业微信登录-vue版

官方文档&#xff1a;developer.work.weixin.qq.com/document/pa… 不需要调用ww.register&#xff0c;直接调用ww.createWWLoginPanel即可创建企业微信登录面板 - 文档 - 企业微信开发者中心 (qq.com) 引入 //通过 npm 引入 npm install wecom/jssdk import * as ww from we…...

OpenGL ES 01 渲染一个四边形

项目架构 着色器封装 vertex #version 300 es // 接收顶点数据 layout (location 0) in vec3 aPos; // 位置变量的属性位置值为0 layout (location 1) in vec4 aColors; // 位置变量的属性位置值为1 out vec4 vertexColor; // 为片段着色器指定一个颜色输出void main() {gl…...

【ETCD】【源码阅读】深入解析 EtcdServer.applyEntries方法

applyEntries方法的主要作用是接收待应用的 Raft 日志条目&#xff0c;并按顺序将其应用到系统中&#xff1b;确保条目的索引连续&#xff0c;避免丢失或重复应用条目。 一、函数完整代码 func (s *EtcdServer) applyEntries(ep *etcdProgress, apply *apply) {if len(apply.…...

概率论得学习和整理28:用EXCEL画折线图,X轴数据也被当成曲线的解决办法

目录 1 折线图和散点图&#xff0c;对数据的处理差别 1.1 EXCEL画图的一些默认设置 1.2 多于2列的数据&#xff0c;也是如此 2 如果我们非要以第1列数据为X轴&#xff0c;做一个折线图呢&#xff1f;也能 2.1 首先&#xff0c;把第1列&#xff0c;想当成X轴的数据&#xf…...

tryhackme-Pre Security-Defensive Security Intro(防御安全简介)

任务一&#xff1a;Introduction to Defensive Security防御安全简介 此room的两个要点&#xff1a; Preventing intrusions from occurring 防止入侵发生Detecting intrusions when they occur and responding properly 检测发生的入侵并正确响应 防御安全还有更多内容。 除上…...

【Axure高保真原型】引导弹窗

今天和大家中分享引导弹窗的原型模板&#xff0c;载入页面后&#xff0c;会显示引导弹窗&#xff0c;适用于引导用户使用页面&#xff0c;点击完成后&#xff0c;会显示下一个引导弹窗&#xff0c;直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...

【人工智能】神经网络的优化器optimizer(二):Adagrad自适应学习率优化器

一.自适应梯度算法Adagrad概述 Adagrad&#xff08;Adaptive Gradient Algorithm&#xff09;是一种自适应学习率的优化算法&#xff0c;由Duchi等人在2011年提出。其核心思想是针对不同参数自动调整学习率&#xff0c;适合处理稀疏数据和不同参数梯度差异较大的场景。Adagrad通…...

Spring Boot 实现流式响应(兼容 2.7.x)

在实际开发中&#xff0c;我们可能会遇到一些流式数据处理的场景&#xff0c;比如接收来自上游接口的 Server-Sent Events&#xff08;SSE&#xff09; 或 流式 JSON 内容&#xff0c;并将其原样中转给前端页面或客户端。这种情况下&#xff0c;传统的 RestTemplate 缓存机制会…...

多模态大语言模型arxiv论文略读(108)

CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文标题&#xff1a;CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文作者&#xff1a;Sayna Ebrahimi, Sercan O. Arik, Tejas Nama, Tomas Pfister ➡️ 研究机构: Google Cloud AI Re…...

无人机侦测与反制技术的进展与应用

国家电网无人机侦测与反制技术的进展与应用 引言 随着无人机&#xff08;无人驾驶飞行器&#xff0c;UAV&#xff09;技术的快速发展&#xff0c;其在商业、娱乐和军事领域的广泛应用带来了新的安全挑战。特别是对于关键基础设施如电力系统&#xff0c;无人机的“黑飞”&…...

uniapp手机号一键登录保姆级教程(包含前端和后端)

目录 前置条件创建uniapp项目并关联uniClound云空间开启一键登录模块并开通一键登录服务编写云函数并上传部署获取手机号流程(第一种) 前端直接调用云函数获取手机号&#xff08;第三种&#xff09;后台调用云函数获取手机号 错误码常见问题 前置条件 手机安装有sim卡手机开启…...

6️⃣Go 语言中的哈希、加密与序列化:通往区块链世界的钥匙

Go 语言中的哈希、加密与序列化:通往区块链世界的钥匙 一、前言:离区块链还有多远? 区块链听起来可能遥不可及,似乎是只有密码学专家和资深工程师才能涉足的领域。但事实上,构建一个区块链的核心并不复杂,尤其当你已经掌握了一门系统编程语言,比如 Go。 要真正理解区…...

结构化文件管理实战:实现目录自动创建与归类

手动操作容易因疲劳或疏忽导致命名错误、路径混乱等问题&#xff0c;进而引发后续程序异常。使用工具进行标准化操作&#xff0c;能有效降低出错概率。 需要快速整理大量文件的技术用户而言&#xff0c;这款工具提供了一种轻便高效的解决方案。程序体积仅有 156KB&#xff0c;…...

基于小程序老人监护管理系统源码数据库文档

摘 要 近年来&#xff0c;随着我国人口老龄化问题日益严重&#xff0c;独居和居住养老机构的的老年人数量越来越多。而随着老年人数量的逐步增长&#xff0c;随之而来的是日益突出的老年人问题&#xff0c;尤其是老年人的健康问题&#xff0c;尤其是老年人产生健康问题后&…...

Qt学习及使用_第1部分_认识Qt---Qt开发基本流程

前言 学以致用,通过QT框架的学习,一边实践,一边探索编程的方方面面. 参考书:<Qt 6 C开发指南>(以下称"本书") 标识说明:概念用粗体倾斜.重点内容用(加粗黑体)---重点内容(红字)---重点内容(加粗红字), 本书原话内容用深蓝色标识,比较重要的内容用加粗倾…...