深度学习笔记(8)预训练模型
深度学习笔记(8)预训练模型
文章目录
- 深度学习笔记(8)预训练模型
- 一、预训练模型构建
- 一、微调模型,训练自己的数据
- 1.导入数据集
- 2.数据集处理方法
- 3.完形填空训练
- 使用分词器将文本转换为模型的输入格式
- 参数 return_tensors="pt" 表示返回PyTorch张量格式
- 执行模型预测
一、预训练模型构建
加载模型和之前一样,用别人弄好的
# 导入warnings模块,用于忽略后续代码中可能出现的警告信息
import warnings
# 设置warnings模块,使其忽略所有警告
warnings.filterwarnings("ignore")# 从transformers库中导入AutoModelForMaskedLM类,该类用于预训练的掩码语言模型
from transformers import AutoModelForMaskedLM# 指定模型检查点,这里使用的是distilbert-base-uncased模型
model_checkpoint = "distilbert-base-uncased"
# 使用from_pretrained方法加载预训练的模型,该方法将从指定的检查点加载模型
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
咱们的任务就是去预测MASK到底是个啥
text = "This is a great " # 定义一个文本字符串# 从指定的模型检查点加载分词器
# model_checkpoint 是之前定义的模型检查点路径,用于加载与模型配套的分词器
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)# 使用分词器将文本转换为模型的输入格式
# 参数 return_tensors="pt" 表示返回PyTorch张量格式
inputs = tokenizer(text, return_tensors="pt")# inputs 现在是一个包含模型输入的张量或字典,可以用于模型推理
{'input_ids': tensor([[ 101, 2023, 2003, 1037, 2307, 103, 1012, 102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
下面的代码可以看到mask的id是103
tokenizer.mask_token_id
103
一、微调模型,训练自己的数据
1.导入数据集
from datasets import load_dataset # 从datasets库中导入load_dataset函数imdb_dataset = load_dataset("imdb") # 使用load_dataset函数加载IMDB数据集
这段代码是使用 datasets 库来加载 IMDB 数据集。IMDB 数据集是一个用于情感分析的经典数据集,包含了两类电影评论:正面和负面。
本身是带标签的 正面和负面
- 0 表示negative
- 1表示positive
先查看下数据集的数据
sample = imdb_dataset["train"].shuffle(seed=42).select(range(3)) # 从训练数据中随机选择3个样本for row in sample:print(f"\n'>>> Review: {row['text']}'") # 打印样本的文本内容print(f"'>>> Label: {row['label']}'") # 打印样本的标签
其中一个如下
'>>> Review: This movie is a great. The plot is very true to the book which is a classic written by Mark Twain. The movie starts of with a scene where Hank sings a song with a bunch of kids called "when you stub your toe on the moon" It reminds me of Sinatra's song High Hopes, it is fun and inspirational. The Music is great throughout and my favorite song is sung by the King, Hank (bing Crosby) and Sir "Saggy" Sagamore. OVerall a great family movie or even a great Date movie. This is a movie you can watch over and over again. The princess played by Rhonda Fleming is gorgeous. I love this movie!! If you liked Danny Kaye in the Court Jester then you will definitely like this movie.'
'>>> Label: 1'
但是我们要做完形填空,标签是没用的
2.数据集处理方法
这里文本长度要统一
- 计算每一个文本的长度(word_ids)
- 指定chunk_size,然后将所有数据按块进行拆分,比如每块128个,句子是700字节,要分成128,128,。。。。这种
先定义个函数,这样下面使用可以直接调用
def tokenize_function(examples):# 调用分词器(tokenizer)的函数,传入输入的文本数据集result = tokenizer(examples["text"])# 如果分词器支持快速模式,则生成单词索引(word_ids)if tokenizer.is_fast:result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]# 返回转换后的结果return result
进行文本处理
tokenized_datasets = imdb_dataset.map(tokenize_function, # 应用tokenize_function函数到每个样本batched=True, # 是否将数据分成批次处理remove_columns=["text", "label"] # 要从数据集中移除的列
)
首先,imdb_dataset.map() 方法被用来应用 tokenize_function 函数到 imdb_dataset 的 train 部分。这个方法会对数据集中的每个样本应用指定的函数,并返回一个新的数据集,其中包含应用函数后的结果。
batched=True 参数告诉 map 方法将输入数据分成批次进行处理。这通常是为了提高效率,尤其是在处理大型数据集时。
remove_columns=[“text”, “label”] 参数告诉 map 方法在处理数据时移除指定的列。在这个例子中,它移除了 text 和 label 列,因为 text 列已经被处理为模型的输入,而 label 列不再需要,因为咱们是完形填空任务,不需要标签。
然后进行切分
tokenizer.model_max_length
chunk_size = 128
tokenizer.model_max_length 是一个属性,它表示模型能够接受的最大输入长度。这个属性通常用于序列标注任务,以确保输入的长度不超过模型的最大接受长度。
chunk_size 是一个参数,用于指定将输入序列分割成小块的大小。这个参数通常用于处理过长的输入序列,以便将其分割成多个小块,然后分别处理这些小块。
因为上限是512 所以你切分要是64 128 256这种
切分的时候也可以先看下文本长度
# 看看每一个都多长
tokenized_samples = tokenized_datasets["train"][:3]for idx, sample in enumerate(tokenized_samples["input_ids"]):print(f"'>>> Review {idx} length: {len(sample)}'")
'>>> Review 0 length: 363'
'>>> Review 1 length: 304'
'>>> Review 2 length: 133'
这里看出128切分比较合适
先拿着三个试试
concatenated_examples = {k: sum(tokenized_samples[k], []) for k in tokenized_samples.keys()#计算拼一起有多少个,
}
total_length = len(concatenated_examples["input_ids"])
print(f"'>>> Concatenated reviews length: {total_length}'")
'>>> Concatenated reviews length: 800'
下面进行切分
chunks = {# 使用字典推导式(dict comprehension)创建一个新的字典,其中键是序列的键(k),值是分割后的块k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]# 遍历concatenated_examples字典中的每个键值对for k, t in concatenated_examples.items()
}for chunk in chunks["input_ids"]:# 打印每个chunk的长度print(f"'>>> Chunk length: {len(chunk)}'")
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 32'
最后那个不够数,直接给他丢弃
def group_texts(examples):# 将所有的文本实例拼接到一起concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}# 计算拼接后的总长度total_length = len(concatenated_examples[list(examples.keys())[0]])# 使用整除运算符(//)计算每个chunk的长度,然后乘以chunk_size,以确保不会出现多余的文本total_length = (total_length // chunk_size) * chunk_size# 根据计算出的总长度,对拼接后的文本进行切分result = {# 使用字典推导式(dict comprehension)创建一个新的字典,其中键是原始字典的键(k),值是切分后的块k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]# 遍历concatenated_examples字典中的每个键值对for k, t in concatenated_examples.items()}# 如果完型填空任务需要使用标签,则将标签复制到结果字典中result["labels"] = result["input_ids"].copy()# 返回分割后的结果return result
使用整除运算符(//)计算每个chunk的长度,然后乘以chunk_size这个就是说如果小于128 整除后就得0,就没了
train: Dataset({features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],num_rows: 61291})test: Dataset({features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],num_rows: 59904})unsupervised: Dataset({features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],num_rows: 122957})
})
会发现数据量大了
3.完形填空训练
from transformers import DataCollatorForLanguageModeling # 从transformers库中导入DataCollatorForLanguageModeling类data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15) # 创建一个数据收集器
然后再举个例子看看mask长啥样
samples = [lm_datasets["train"][i] for i in range(2)] # 从训练数据中选择2个样本
for sample in samples:_ = sample.pop("word_ids") # 移除样本中的"word_ids"键,因为发现没啥用print(sample)for chunk in data_collator(samples)["input_ids"]:print(f"\n'>>> {tokenizer.decode(chunk)}'")#tokenizer.decode(chunk)是为了让被隐藏的更明显点,所以直接decode出来print(len(chunk))
训练过程
from transformers import TrainingArguments # 从transformers库中导入TrainingArguments类batch_size = 64 # 定义训练和评估时的批次大小
# 计算每个epoch打印结果的步数
logging_steps = len(downsampled_dataset["train"]) // batch_size
model_name = model_checkpoint.split("/")[-1] # 从模型检查点路径中提取模型名称training_args = TrainingArguments(output_dir=f"{model_name}-finetuned-imdb", # 指定输出目录,其中包含微调后的模型和日志文件overwrite_output_dir=True, # 是否覆盖现有的输出目录evaluation_strategy="epoch", # 指定评估策略,这里为每个epoch评估一次learning_rate=2e-5, # 学习率weight_decay=0.01, # 权重衰减系数per_device_train_batch_size=batch_size, # 每个GPU的训练批次大小per_device_eval_batch_size=batch_size, # 每个GPU的评估批次大小logging_steps=logging_steps, # 指定每个epoch打印结果的步数num_train_epochs=1, # 指定训练的epoch数量save_strategy='epoch', # 指定保存策略,这里为每个epoch保存一次模型
)
生成的模型
from transformers import Trainer # 从transformers库中导入Trainer类trainer = Trainer(model=model, # 指定要训练的模型实例args=training_args, # 指定训练参数对象train_dataset=downsampled_dataset["train"], # 指定训练数据集eval_dataset=downsampled_dataset["test"], # 指定评估数据集data_collator=data_collator, # 指定数据收集器
)
评估标准使用困惑度
import math # 导入math模块,用于计算对数和指数eval_results = trainer.evaluate() # 使用Trainer的evaluate方法评估模型print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}") # 打印评估结果中的对数
人话就是你不得在mask那挑啥词合适吗,平均挑了多少个才能答对
训练模型
trainer.train()
eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
发现困惑度降低了
from transformers import AutoModelForMaskedLMmodel_checkpoint = "./distilbert-base-uncased"
model = AutoModelForMaskedLM.from_pretrained("./distilbert-base-uncased-finetuned-imdb/checkpoint-157")
加载自己的模型
import torch # 导入torch库,用于处理张量
使用分词器将文本转换为模型的输入格式
参数 return_tensors=“pt” 表示返回PyTorch张量格式
inputs = tokenizer(text, return_tensors=“pt”)
执行模型预测
# model 是一个预训练的BERT模型实例
# **inputs 表示将inputs字典中的所有键值对作为关键字参数传递给model
token_logits = model(**inputs).logits# 找到遮蔽词在输入中的索引
# inputs["input_ids"] 是模型输入的词汇索引张量
# tokenizer.mask_token_id 是遮蔽词的词汇索引
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]# 获取遮蔽词的预测logits
# mask_token_index 是遮蔽词在输入中的索引张量
# token_logits 是模型输出的预测logits张量
mask_token_logits = token_logits[0, mask_token_index, :]# 找到前5个最可能的替换词
# torch.topk 函数用于找到最大k个值及其索引
# dim=1 表示在第二个维度(即词汇维度)上进行排序
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()# 打印每个最可能的替换词及其替换后的文本
for token in top_5_tokens:# 使用 tokenizer.decode 方法将索引转换为文本# text 是原始文本# tokenizer.mask_token 是遮蔽词# [token] 是替换词的索引列表print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")
'>>> This is a great deal.'
'>>> This is a great idea.'
'>>> This is a great adventure.'
'>>> This is a great film.'
'>>> This is a great movie.'
可以看到结果和上面的通用的不一样了,这里是film movie这些针对你的训练数据的了--------
相关文章:

深度学习笔记(8)预训练模型
深度学习笔记(8)预训练模型 文章目录 深度学习笔记(8)预训练模型一、预训练模型构建一、微调模型,训练自己的数据1.导入数据集2.数据集处理方法3.完形填空训练 使用分词器将文本转换为模型的输入格式参数 return_tenso…...

C#事件的用法
前言 在C#中,事件(Event)可以实现当类内部发生某些特定的事情时,它可以通知其他类或对象。事件是基于委托(Delegate)的,委托是一种类型安全的函数指针,它定义了方法的类型ÿ…...

金砖软件测试赛项之Jmeter如何录制脚本!
一、简介 Apache JMeter 是一款开源的性能测试工具,用于测试各种服务的负载能力,包括Web应用、数据库、FTP服务器等。它可以模拟多种用户行为,生成负载以评估系统的性能和稳定性。 JMeter 的主要特点: 图形用户界面:…...

docker-squash镜像压缩
docker-squash 和 docker export docker load 的原理和效果有一些相似之处,但它们的工作方式和适用场景有所不同。 docker-squash docker-squash 是一个工具,它通过分析 Docker 镜像的层(layers)并将其压缩成更少的层来减小镜像…...

Vue3快速入门+axios的异步请求(基础使用)
学习Vue之前先要学习htmlcssjs的基础使用 Vue其实是js的框架 常用到的Vue指令包括vue-on,vue-for,vue-blind,vue-if&vue-show,v-modul vue的基础模板: <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8&…...

VM16安装macOS11
注意: 本文内容于 2024-09-17 12:08:24 创建,可能不会在此平台上进行更新。如果您希望查看最新版本或更多相关内容,请访问原文地址:VM16安装macOS11。感谢您的关注与支持! 使用 Vmware Workstation Pro 16 安装 macOS…...

自定义复杂AntV/G6案例
一、效果图 二、源码 /** * * Author: me * CreatDate: 2024-08-22 * * Description: 复杂G6案例 * */ <template><div class"moreG6-wapper"><div id"graphContainer" ref"graphRef" class"graph-content"></d…...

Golang | Leetcode Golang题解之第419题棋盘上的战舰
题目: 题解: func countBattleships(board [][]byte) (ans int) {for i, row : range board {for j, ch : range row {if ch X && !(i > 0 && board[i-1][j] X || j > 0 && board[i][j-1] X) {ans}}}return }...

CCF刷题计划——LDAP(交集、并集 how to go)
LDAP 计算机软件能力认证考试系统 不知道为什么,直接给我报一个运行错误,得了0分。但是我在Dev里,VS里面都跑的好好的,奇奇怪怪。如果有大佬路过,请帮小弟看看QWQ。本题学到的:交集set_intersection、并集…...

谷歌论文提前揭示o1模型原理:AI大模型竞争或转向硬件
Open AI最强模型o1的护城河已经没有了?仅在OpenAI发布最新推理模型o1几日之后,海外社交平台 Reddit 上有网友发帖称谷歌Deepmind在 8 月发表的一篇论文内容与o1模型原理几乎一致,OpenAI的护城河不复存在。 谷歌DeepMind团队于今年8月6日发布…...

【ShuQiHere】 探索数据挖掘的世界:从概念到应用
🌐 【ShuQiHere】 数据挖掘(Data Mining, DM) 是一种从大型数据集中提取有用信息的技术,无论是在商业分析、金融预测,还是医学研究中,数据挖掘都扮演着至关重要的角色。本文将带您深入了解数据挖掘的核心概…...

LabVIEW提高开发效率技巧----使用事件结构优化用户界面响应
事件结构(Event Structure) 是 LabVIEW 中用于处理用户界面事件的强大工具。通过事件驱动的编程方式,程序可以在用户操作时动态执行特定代码,而不是通过轮询(Polling)的方式不断检查界面控件状态。这种方式…...

【前端】ES6:Set与Map
文章目录 1 Set结构1.1 初识Set1.2 实例的属性和方法1.3 遍历1.4 复杂数据结构去重 2 Map结构2.1 初识Map2.2 实例的属性和方法2.3 遍历 1 Set结构 它类似于数组,但成员的值都是唯一的,没有重复的值。 1.1 初识Set let s1 new Set([1, 2, 3, 2, 3]) …...

Java 之网络编程小案例
1. 多发多收 描述: 编写一个简单的聊天程序,客户端可以向服务器发送多条消息,服务器可以接收所有消息并回复。 代码示例: 服务器端 (Server.java): import java.io.*; import java.net.*; import java.util.concurrent.Execut…...

Spring Boot:现代化Java应用开发的艺术
目录 什么是Spring Boot? 为什么选择Spring Boot? Spring Boot的核心概念 详细步骤:创建一个Spring Boot应用 步骤1:使用Spring Initializr创建项目 步骤2:解压并导入项目 步骤3:构建和配置项目 po…...

Redis五种基本数据结构的使用
Redis具有五种基本数据类型:String(字符串)、Hash(哈希)、List(列表)、Set(集合)、SortedSet(有序集合),下面示意它们的使用。 String类数据类型的使用 增:添加数据(set)、添加多个数据(mset)、添加数据时指定过期时间(setex) 删…...

【QT】系统-下
欢迎来到Cefler的博客😁 🕌博客主页:折纸花满衣 🏠个人专栏:QT 目录 👉🏻QTheadrun() 👉🏻QMutex👉🏻QWaitCondition👉🏻Q…...

java和kotlin 可以同时运行吗
Java 和 Kotlin 可以同时运行在同一个项目中,这主要得益于 Kotlin 对 Java 的互操作性。Kotlin 被设计为与 Java 100% 兼容,这意味着 Kotlin 代码可以很容易地调用 Java 代码,反之亦然。这种设计使得 Kotlin 能够无缝集成到现有的 Java 项目中…...

2024最新版 Tuxera NTFS for Mac 2023绿色版图文安装教程
在数字化时代,数据的存储和传输变得至关重要。Mac用户经常需要在Windows NTFS格式的移动硬盘上进行读写操作,然而,由于MacOS系统默认不支持NTFS的写操作,这就需要我们寻找一款高效的读写软件。Tuxera NTFS for Mac 2023便是其中…...

npm发布插件超级简单版
在开源的世界里,每个人都有机会成为贡献者,甚至是创新的引领者。您是否有过这样的想法:开发一个解决特定问题的小工具,让他成为其他开发者手中的利器?今天,我们就来一场实战训练,学习如何将你的…...

C# 访问Access存取图片
图片存入ole字段,看有的代码是获取图片的字节数组转换为base64字符串,存入数据库;显示图片是把base64字符串转换为字节数组再显示;直接存字节数组可能还好一点; 插入的时候用带参数的sql写法比较好;用拼接…...

正则表达式中常见字符的用法介绍
正则表达式(Regular Expression,简称Regex)是一种文本模式描述的方法,包括普通字符(如a到z之间的字母)和特殊字符(称为“元字符”)。正则表达式使用单个字符串来描述、匹配一系列符合…...

Vue3.0组合式API:依赖注入provide和inject实现跨层组件的通信
Vue3.0组合式API系列文章: 《Vue3.0组合式API:setup()函数》 《Vue3.0组合式API:使用reactive()、ref()创建响应式代理对象》 《Vue3.0组合式API:computed计算属性、watch监听器、watchEffect高级监听器》 《Vue3.0组合式API&…...

VSCode中配置C/C++环境
在Visual Studio Code(VSCode)中配置C/C环境是一个相对直接且功能强大的过程,它能让开发者利用VSCode的诸多便利功能来编写、编译和调试C/C代码。以下是一个详细的步骤指南,涵盖了从安装必要的软件到配置编译器、调试器以及VSCode…...

vue实现鼠标滚轮控制页面横向滑动
先看效果 20240919_095531 1.首先创建一个xScroll.vue组件 <template><div class"main" v-size-ob"mainSize"><div class"v-scroll"><div class"content"><slot></slot></div></div>…...

【Git使用】删除Github仓库中的指定文件/文件夹
前言: 上篇文章带大家上传了第一个项目至github,那要是想删除仓库中的指定文件夹怎么办?在Github中 仓库是无法通过鼠标操作直接删除文件和文件夹的,那只能通过 git 命令来执行删除操作。接下来就带大家进行操作。 详细步骤: 一…...

Iptables命令常用命令
前言:下是一些非常实用的 iptables 命令合集,涵盖网络攻击防护和日常网络安全防护 1. 查看当前规则 iptables -L -v -n查看现有的所有规则,-v 显示详细信息,-n 禁止解析IP地址和端口以加快显示速度。 2. 清空所有规则 iptables -F清除所有已…...

前端开发之原型模式
介绍 原型模式本质就是借用一个已有的实例做原型,在这原型基础上快速复制出一个和原型一样的一个对象。 class CloneDemo {name clone democlone(): CloneDemo {return new CloneDemo()} } 原型原型链 函数(class)都有显示原型 prototyp…...

分布式缓存服务Redis版解析与配置方式
一、Redis分布式缓存服务概述 Redis是一款高性能的键值对(Key-Value)存储系统,通常用作分布式缓存服务。它基于内存运行,支持丰富的数据类型,并具备高并发、低延迟的特点,非常适合用于缓存需要频繁访问的数…...

WordPress建站钩子函数及使用
目录 前言: 使用场景: 一、常用的wordpress钩子(动作钩子、过滤器钩子) 1、动作钩子(Action Hooks) 2、过滤器钩子(Filter Hooks) 二、常用钩子示例 1、添加自定义 CSS 和 JS…...