当前位置: 首页 > news >正文

基于ChatYuan-large-v2 语言模型 Fine-tuning 微调训练 广告生成 任务

一、ChatYuan-large-v2

ChatYuan-large-v2是一个开源的支持中英双语的功能型对话语言大模型,与其他 LLM 不同的是模型十分轻量化,并且在轻量化的同时效果相对还不错,仅仅通过0.7B参数量就可以实现10B模型的基础效果,正是其如此的轻量级,使其可以在普通显卡、 CPU、甚至手机上进行推理,而且 INT4 量化后的最低只需 400M

v2 版本相对于以前的 v1 版本,是使用了相同的技术方案,但在指令微调、人类反馈强化学习、思维链等方面进行了优化。

在本专栏前面文章介绍了 ChatYuan-large-v2langchain 相结合的使用,地址如下:

LangChain 本地化方案 - 使用 ChatYuan-large-v2 作为 LLM 大语言模型

本篇文章以 ChatYuan-large-v2 模型为基础 Fine-tuning 广告生成 任务。

二、数据集处理

数据集这里使用 ChatGLM 官方在 Fine-tuning 中使用到的 广告生成 数据集。

下载地址如下:

https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view

数据已 JSON 的形式存放,分为了 traindev 两种类型:

在这里插入图片描述
数据格式如下所示:

{"content":"类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤","summary":"宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"
}
{"content":"类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤","summary":"宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"
}
{"content":"类型#裙*风格#简约*图案#条纹*图案#线条*图案#撞色*裙型#鱼尾裙*裙袖长#无袖","summary":"圆形领口修饰脖颈线条,适合各种脸型,耐看有气质。无袖设计,尤显清凉,简约横条纹装饰,使得整身人鱼造型更为生动立体。加之撞色的鱼尾下摆,深邃富有诗意。收腰包臀,修饰女性身体曲线,结合别出心裁的鱼尾裙摆设计,勾勒出自然流畅的身体轮廓,展现了婀娜多姿的迷人姿态。"
}
{"content":"类型#上衣*版型#宽松*颜色#粉红色*图案#字母*图案#文字*图案#线条*衣样式#卫衣*衣款式#不规则","summary":"宽松的卫衣版型包裹着整个身材,宽大的衣身与身材形成鲜明的对比描绘出纤瘦的身形。下摆与袖口的不规则剪裁设计,彰显出时尚前卫的形态。被剪裁过的样式呈现出布条状自然地垂坠下来,别具有一番设计感。线条分明的字母样式有着花式的外观,棱角分明加上具有少女元气的枣红色十分有年轻活力感。粉红色的衣身把肌肤衬托得很白嫩又健康。"
}
{"content":"类型#裙*版型#宽松*材质#雪纺*风格#清新*裙型#a字*裙长#连衣裙","summary":"踩着轻盈的步伐享受在午后的和煦风中,让放松与惬意感为你免去一身的压力与束缚,仿佛要将灵魂也寄托在随风摇曳的雪纺连衣裙上,吐露出<UNK>微妙而又浪漫的清新之意。宽松的a字版型除了能够带来足够的空间,也能以上窄下宽的方式强化立体层次,携带出自然优雅的曼妙体验。"
}
{"content":"类型#上衣*材质#棉*颜色#蓝色*风格#潮*衣样式#polo*衣领型#polo领*衣袖长#短袖*衣款式#拼接","summary":"想要在人群中脱颖而出吗?那么最适合您的莫过于这款polo衫短袖,采用了经典的polo领口和柔软纯棉面料,让您紧跟时尚潮流。再配合上潮流的蓝色拼接设计,使您的风格更加出众。就算单从选料上来说,这款polo衫的颜色沉稳经典,是这个季度十分受大众喜爱的风格了,而且兼具舒适感和时尚感。"
}

其中任务的方式为根据输入(content)生成一段广告词(summary)。

train.json 共有 114599 条记录,这里为了演示效果取前 50000 条数据进行训练、5000 条数据进行验证:

import os# 将训练集进行提取
def doHandle(json_path, train_size, val_size, out_json_path):train_count = 0val_count = 0train_f = open(os.path.join(out_json_path, "train.json"), "a", encoding='utf-8')val_f = open(os.path.join(out_json_path, "val.json"), "a", encoding='utf-8')with open(json_path, "r", encoding='utf-8') as f:for line in f:if train_count < train_size:train_f.writelines(line)train_count = train_count + 1elif val_count < val_size:val_f.writelines(line)val_count = val_count + 1else:breakprint("数据处理完毕!")train_f.close()val_f.close()if __name__ == '__main__':json_path = "./data/AdvertiseGen/train.json"out_json_path = "./data/"train_size = 50000val_size = 5000doHandle(json_path, train_size, val_size, out_json_path)

处理之后可以看到两个生成的文件:

在这里插入图片描述

下面基于上面的数据格式构建 Dataset

from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import torch
import jsonclass SummaryDataSet(Dataset):def __init__(self, json_path: str, tokenizer, max_length=300):self.tokenizer = tokenizerself.max_length = max_lengthself.content_data = []self.summary_data = []with open(json_path, "r", encoding='utf-8') as f:for line in f:if not line or line == "":continuejson_line = json.loads(line)content = json_line["content"]summary = json_line["summary"]self.content_data.append(content)self.summary_data.append(summary)print("data load , size:", len(self.content_data))def __len__(self):return len(self.content_data)def __getitem__(self, index):source_text = str(self.content_data[index])target_text = str(self.summary_data[index])source = self.tokenizer.batch_encode_plus([source_text],max_length=self.max_length,pad_to_max_length=True,truncation=True,padding="max_length",return_tensors="pt",)target = self.tokenizer.batch_encode_plus([target_text],max_length=self.max_length,pad_to_max_length=True,truncation=True,padding="max_length",return_tensors="pt",)source_ids = source["input_ids"].squeeze()source_mask = source["attention_mask"].squeeze()target_ids = target["input_ids"].squeeze()target_mask = target["attention_mask"].squeeze()return {"source_ids": source_ids.to(dtype=torch.long),"source_mask": source_mask.to(dtype=torch.long),"target_ids": target_ids.to(dtype=torch.long),"target_ids_y": target_ids.to(dtype=torch.long),}

三、模型训练

下载 ChatYuan-large-v2 模型:

https://huggingface.co/ClueAI/ChatYuan-large-v2/tree/main

在这里插入图片描述
在这里插入图片描述

下面基于 ChatYuan-large-v2 进行训练:

import pandas as pd
import torch
from torch.utils.data import DataLoader
import os, time
from transformers import T5Tokenizer, T5ForConditionalGeneration
from gen_dataset import SummaryDataSetdef train(epoch, tokenizer, model, device, loader, optimizer):model.train()time1 = time.time()for _, data in enumerate(loader, 0):y = data["target_ids"].to(device, dtype=torch.long)y_ids = y[:, :-1].contiguous()lm_labels = y[:, 1:].clone().detach()lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100ids = data["source_ids"].to(device, dtype=torch.long)mask = data["source_mask"].to(device, dtype=torch.long)outputs = model(input_ids=ids,attention_mask=mask,decoder_input_ids=y_ids,labels=lm_labels,)loss = outputs[0]# 每100步打印日志if _ % 100 == 0 and _ != 0:time2 = time.time()print(_, "epoch:" + str(epoch) + "-loss:" + str(loss) + ";each step's time spent:" + str(float(time2 - time1) / float(_ + 0.0001)))optimizer.zero_grad()loss.backward()optimizer.step()def validate(epoch, tokenizer, model, device, loader, max_length):model.eval()predictions = []actuals = []with torch.no_grad():for _, data in enumerate(loader, 0):y = data['target_ids'].to(device, dtype=torch.long)ids = data['source_ids'].to(device, dtype=torch.long)mask = data['source_mask'].to(device, dtype=torch.long)generated_ids = model.generate(input_ids=ids,attention_mask=mask,max_length=max_length,num_beams=2,repetition_penalty=2.5,length_penalty=1.0,early_stopping=True)preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g ingenerated_ids]target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]if _ % 1000 == 0:print(f'Completed {_}')predictions.extend(preds)actuals.extend(target)return predictions, actualsdef T5Trainer(train_json_path, val_json_path, model_dir, batch_size, epochs, output_dir, max_length=300):tokenizer = T5Tokenizer.from_pretrained(model_dir)model = T5ForConditionalGeneration.from_pretrained(model_dir)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = model.to(device)train_params = {"batch_size": batch_size,"shuffle": True,"num_workers": 0,}training_set = SummaryDataSet(train_json_path, tokenizer, max_length=max_length)training_loader = DataLoader(training_set, **train_params)val_params = {"batch_size": batch_size,"shuffle": False,"num_workers": 0,}val_set = SummaryDataSet(val_json_path, tokenizer, max_length=max_length)val_loader = DataLoader(val_set, **val_params)optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)for epoch in range(epochs):train(epoch, tokenizer, model, device, training_loader, optimizer)print("保存模型")model.save_pretrained(output_dir)tokenizer.save_pretrained(output_dir)# 验证with torch.no_grad():predictions, actuals = validate(epoch, tokenizer, model, device, val_loader, max_length)# 验证结果存储final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})final_df.to_csv(os.path.join(output_dir, "predictions.csv"))if __name__ == '__main__':train_json_path = "./data/train.json"val_json_path = "./data/val.json"# 下载模型目录位置model_dir = "chatyuan_large_v2"batch_size = 5epochs = 1max_length = 300output_dir = "./model"# 开始训练T5Trainer(train_json_path,val_json_path,model_dir,batch_size,epochs,output_dir,max_length)

运行后可以看到如下日志打印,训练大概占用 33G 的显存,如果显存不够可以调低些 batch_size 的大小:

在这里插入图片描述

等待训练结束后:

在这里插入图片描述

可以在 model 下看到保存的模型:

在这里插入图片描述

这里可以先看下 predictions.csv 验证集的效果:

在这里插入图片描述

可以看到模型生成的结果有点不太好,这里仅对前 50000 条进行了训练,并且就训练了一个 epoch ,后面可以增加数据集大小和增加 epoch 应该能达到更好的效果,下面通过调用模型测试一下生成的文本效果。

四、模型测试

# -*- coding: utf-8 -*-
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch# 这里是模型下载的位置
model_dir = './model'tokenizer = T5Tokenizer.from_pretrained(model_dir)
model = T5ForConditionalGeneration.from_pretrained(model_dir)while True:text = input("请输入内容: \n ")if not text or text == "":continueif text == "q":breakencoded_input = tokenizer(text, padding="max_length", truncation=True, max_length=300)input_ids = torch.tensor([encoded_input['input_ids']])attention_mask = torch.tensor([encoded_input['attention_mask']])generated_ids = model.generate(input_ids=input_ids,attention_mask=attention_mask,max_length=300,num_beams=2,repetition_penalty=2.5,length_penalty=1.0,early_stopping=True)reds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g ingenerated_ids]print(reds)

效果测试:

在这里插入图片描述

相关文章:

基于ChatYuan-large-v2 语言模型 Fine-tuning 微调训练 广告生成 任务

一、ChatYuan-large-v2 ChatYuan-large-v2是一个开源的支持中英双语的功能型对话语言大模型&#xff0c;与其他 LLM 不同的是模型十分轻量化&#xff0c;并且在轻量化的同时效果相对还不错&#xff0c;仅仅通过0.7B参数量就可以实现10B模型的基础效果&#xff0c;正是其如此的…...

SpringBoot集成Logback日志

SpringBoot集成Logback日志 文章目录 SpringBoot集成Logback日志一、什么是日志二、Logback简单介绍三、SpringBoot项目中使用Logback四、概念介绍一、日志记录器Logger1.1、日志记录器对象生成1.2、记录器的层级结构1.3、过滤器1.4、logger设置日志级别1.5、java代码演示1.6、…...

MATLAB(R2023a)添加工具箱TooLbox的方法-以GPOPS为例

一、找到工具箱存放位置 首先我们需要找到工具箱的存放位置&#xff0c;点击这个设置路径可以看到 我们的matlab工具箱的存放位置 C:\Program Files\MATLAB\R2023a\toolbox\matlab 从资源管理器中打开这个位置&#xff0c;可以看到里面各种工具箱 二、放入工具箱 解压我们…...

助力618-Y的混沌实践之路 | 京东云技术团队

一、写在前面 1、混沌是什么&#xff1f; 混沌工程&#xff08;Chaos Engineering&#xff09;的概念由 Netflix 在 2010 年提出&#xff0c;通过主动向系统中引入异常状态&#xff0c;并根据系统在各种压力下的行为表现确定优化策略&#xff0c;是保障系统稳定性的新型手段。…...

Python系统学习1-4-物理行、逻辑行、选择语句

一、行 (1) 物理行&#xff1a;程序员编写代码的行。 (2) 逻辑行&#xff1a;python解释器需要执行的指令。 (3) 建议&#xff1a; 一个逻辑行在一个物理行上。 如果一个物理行中使用多个逻辑行&#xff0c;需要使用分号&#xff1b;隔开。 (4) 换行&#xff1a; 如果…...

学习系统编程No.35【基于信号量的CP问题】

引言&#xff1a; 北京时间&#xff1a;2023/8/2/12:52&#xff0c;时间飞逝&#xff0c;恍惚间已经来到了八月&#xff0c;给我的第一感觉就是快开学了&#xff0c;别的感觉其实没有&#xff0c;哈哈&#xff01;看着身边的好友网络相关知识都要全部学完了&#xff0c;就好像…...

词嵌入、情感分类任务

目录 1.词嵌入&#xff08;word embedding&#xff09; 对单词使用one-hot编码的缺点是难以看出词与词之间的关系。 所以需要使用更加特征化的表示&#xff08;featurized representation&#xff09;&#xff0c;如下图所示&#xff0c;我们可以得到每个词的向量表达。 假设…...

TypeScript使用技巧

文章目录 使用技巧TypeScript内置的工具类型keyofextends 限定泛型interface 与 type 区别 TypeScript作为JavaScript的超集,通过提供静态类型系统和对ES6新特性的支持,使JavaScript开发变得更加高效和可维护。掌握TypeScript的使用技巧,可以帮助我们更好地开发和组织JavaScrip…...

MySQL — InnoDB事务

文章目录 事务定义事务特性事务隔离级别READ UNCOMMITTEDREPEATABLE READREAD COMMITTEDSERIALIZABLE 事务存在的问题脏读&#xff08;Dirty Read&#xff09;不可重复读&#xff08;Non-repeatable Read&#xff09;幻读&#xff08;Phantom Read&#xff09; 事务定义 数据库…...

LeetCode 42. 接雨水(动态规划 / 单调栈)

题目&#xff1a; 链接&#xff1a;LeetCode 42. 接雨水 难度&#xff1a;困难 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图&#xff0c;计算按此排列的柱子&#xff0c;下雨之后能接多少雨水。 示例 1&#xff1a; 输入&#xff1a;height [0,1,0,2,1,0,1,3,2,1,2…...

顺序表、链表刷题指南(力扣OJ)

目录 前言 题目一&#xff1a;删除有序数组中的重复项 思路&#xff1a; 题解&#xff1a; 题目二&#xff1a;合并两个有序数组 思路&#xff1a; 分析&#xff1a; 题解&#xff1a; 题目三&#xff1a;反转链表 思路&#xff1a; 分析&#xff1a; 题解&#xff1a; 题目四&…...

Lambda表达式总结

Lambda作为Java8的新特性&#xff0c;本篇文章主要想总结一下常用的一下用法和api 1.接口内默认方法实现 public interface Formula {double calculate(int a);// 默认方法default double sqrt(int a) {return Math.sqrt(a);} }public static void main(String[] args) {Form…...

岛屿的最大面积

给你一个大小为 m x n 的二进制矩阵 grid 。 岛屿 是由一些相邻的 1 (代表土地) 构成的组合&#xff0c;这里的「相邻」要求两个 1 必须在 水平或者竖直的四个方向上 相邻。你可以假设 grid 的四个边缘都被 0&#xff08;代表水&#xff09;包围着。 岛屿的面积是岛上值为 1 …...

迭代器模式(Iterator)

迭代器模式是一种行为设计模式&#xff0c;可以在不暴露底层实现(列表、栈或树等)的情况下&#xff0c;遍历一个聚合对象中所有的元素。 Iterator is a behavior design pattern that can traverse all elements of an aggregate object without exposing the internal imple…...

Goland搭建远程Linux开发

Windows和Linux都需要先构建好go环境&#xff0c;启用ssh服务。 打开Windows上的Goland&#xff0c;建立项目。 点击添加配置&#xff0c;选择go构建 点击运行于&#xff0c;选择ssh 填上Linux机器的IP地址和用户名 输入密码 没有问题 为了不让每次运行程序和调试程序都生…...

react中PureComponent的理解与使用

一、作用 它是一个纯组件&#xff0c;会做一个数据的浅比较&#xff0c;当props和state没改变的时候&#xff0c;不会render重新渲染&#xff0c; 改变后才会render重新渲染&#xff0c;提高性能。 二、使用 三、注意 它不能和shouldComponentUpdate生命周期同时使用。因为它…...

洛谷——P5714 【深基3.例7】肥胖问题

文章目录 题目题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 样例 #2样例输入 #2样例输出 #2 提示 AC代码 题目 题目描述 BMI 指数是国际上常用的衡量人体胖瘦程度的一个标准&#xff0c;其算法是 m h 2 \dfrac{m}{h^2} h2m​&#xff0c;其中 m m m 是指体重&am…...

Mac隐藏和显示文件

由于之前没有使用过Mac本&#xff0c;所以很多地方都不太清楚&#xff0c;在下载git项目的时候&#xff0c;发现没有.git文件&#xff0c; 一开始还以为下载错了&#xff0c;但是git命令是可以看到远端分支以及当前分支的&#xff0c;之后在一次解压文件的时候发现&#xff0c;…...

软件工程中应用的几种图辨析

【软件工程】软件工程中应用的几种图辨析&#xff1a;系统流程图、数据流图、数据字典、实体联系图、状态转换图、层次方框图、Warnier图、IPO图、层次图、HIPO图、结构图、程序流程图、盒图、PAD图、判定表_眩晕李的博客-CSDN博客 软件工程——实体关系图 状态转换图 数据流…...

下载离线版的VS Visual Studio 并下载指定的版本

一、先下载引导程序 下载地址VS VisualStudio官网 在这个页面翻到最下面 在这里下载需要的版本 下载引导程序 二、下载离线安装包 写一个批处理文件&#xff08;vs.bat&#xff09; 命令格式如下 <vs引导程序exe> --layout <离线安装包下载的路径> --add <功能…...

逻辑回归:给不确定性划界的分类大师

想象你是一名医生。面对患者的检查报告&#xff08;肿瘤大小、血液指标&#xff09;&#xff0c;你需要做出一个**决定性判断**&#xff1a;恶性还是良性&#xff1f;这种“非黑即白”的抉择&#xff0c;正是**逻辑回归&#xff08;Logistic Regression&#xff09;** 的战场&a…...

Java - Mysql数据类型对应

Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...

【ROS】Nav2源码之nav2_behavior_tree-行为树节点列表

1、行为树节点分类 在 Nav2(Navigation2)的行为树框架中,行为树节点插件按照功能分为 Action(动作节点)、Condition(条件节点)、Control(控制节点) 和 Decorator(装饰节点) 四类。 1.1 动作节点 Action 执行具体的机器人操作或任务,直接与硬件、传感器或外部系统…...

vue3 定时器-定义全局方法 vue+ts

1.创建ts文件 路径&#xff1a;src/utils/timer.ts 完整代码&#xff1a; import { onUnmounted } from vuetype TimerCallback (...args: any[]) > voidexport function useGlobalTimer() {const timers: Map<number, NodeJS.Timeout> new Map()// 创建定时器con…...

零基础设计模式——行为型模式 - 责任链模式

第四部分&#xff1a;行为型模式 - 责任链模式 (Chain of Responsibility Pattern) 欢迎来到行为型模式的学习&#xff01;行为型模式关注对象之间的职责分配、算法封装和对象间的交互。我们将学习的第一个行为型模式是责任链模式。 核心思想&#xff1a;使多个对象都有机会处…...

mysql已经安装,但是通过rpm -q 没有找mysql相关的已安装包

文章目录 现象&#xff1a;mysql已经安装&#xff0c;但是通过rpm -q 没有找mysql相关的已安装包遇到 rpm 命令找不到已经安装的 MySQL 包时&#xff0c;可能是因为以下几个原因&#xff1a;1.MySQL 不是通过 RPM 包安装的2.RPM 数据库损坏3.使用了不同的包名或路径4.使用其他包…...

【数据分析】R版IntelliGenes用于生物标志物发现的可解释机器学习

禁止商业或二改转载&#xff0c;仅供自学使用&#xff0c;侵权必究&#xff0c;如需截取部分内容请后台联系作者! 文章目录 介绍流程步骤1. 输入数据2. 特征选择3. 模型训练4. I-Genes 评分计算5. 输出结果 IntelliGenesR 安装包1. 特征选择2. 模型训练和评估3. I-Genes 评分计…...

智能AI电话机器人系统的识别能力现状与发展水平

一、引言 随着人工智能技术的飞速发展&#xff0c;AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术&#xff0c;在客户服务、营销推广、信息查询等领域发挥着越来越重要…...

iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈

在日常iOS开发过程中&#xff0c;性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期&#xff0c;开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发&#xff0c;但背后往往隐藏着系统资源调度不当…...

【电力电子】基于STM32F103C8T6单片机双极性SPWM逆变(硬件篇)

本项目是基于 STM32F103C8T6 微控制器的 SPWM(正弦脉宽调制)电源模块,能够生成可调频率和幅值的正弦波交流电源输出。该项目适用于逆变器、UPS电源、变频器等应用场景。 供电电源 输入电压采集 上图为本设计的电源电路,图中 D1 为二极管, 其目的是防止正负极电源反接, …...