深度学习笔记(8)预训练模型
深度学习笔记(8)预训练模型
文章目录
- 深度学习笔记(8)预训练模型
- 一、预训练模型构建
- 一、微调模型,训练自己的数据
- 1.导入数据集
- 2.数据集处理方法
- 3.完形填空训练
- 使用分词器将文本转换为模型的输入格式
- 参数 return_tensors="pt" 表示返回PyTorch张量格式
- 执行模型预测
一、预训练模型构建
加载模型和之前一样,用别人弄好的
# 导入warnings模块,用于忽略后续代码中可能出现的警告信息
import warnings
# 设置warnings模块,使其忽略所有警告
warnings.filterwarnings("ignore")# 从transformers库中导入AutoModelForMaskedLM类,该类用于预训练的掩码语言模型
from transformers import AutoModelForMaskedLM# 指定模型检查点,这里使用的是distilbert-base-uncased模型
model_checkpoint = "distilbert-base-uncased"
# 使用from_pretrained方法加载预训练的模型,该方法将从指定的检查点加载模型
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
咱们的任务就是去预测MASK到底是个啥
text = "This is a great " # 定义一个文本字符串# 从指定的模型检查点加载分词器
# model_checkpoint 是之前定义的模型检查点路径,用于加载与模型配套的分词器
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)# 使用分词器将文本转换为模型的输入格式
# 参数 return_tensors="pt" 表示返回PyTorch张量格式
inputs = tokenizer(text, return_tensors="pt")# inputs 现在是一个包含模型输入的张量或字典,可以用于模型推理
{'input_ids': tensor([[ 101, 2023, 2003, 1037, 2307, 103, 1012, 102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
下面的代码可以看到mask的id是103
tokenizer.mask_token_id
103
一、微调模型,训练自己的数据
1.导入数据集
from datasets import load_dataset # 从datasets库中导入load_dataset函数imdb_dataset = load_dataset("imdb") # 使用load_dataset函数加载IMDB数据集
这段代码是使用 datasets 库来加载 IMDB 数据集。IMDB 数据集是一个用于情感分析的经典数据集,包含了两类电影评论:正面和负面。
本身是带标签的 正面和负面
- 0 表示negative
- 1表示positive
先查看下数据集的数据
sample = imdb_dataset["train"].shuffle(seed=42).select(range(3)) # 从训练数据中随机选择3个样本for row in sample:print(f"\n'>>> Review: {row['text']}'") # 打印样本的文本内容print(f"'>>> Label: {row['label']}'") # 打印样本的标签
其中一个如下
'>>> Review: This movie is a great. The plot is very true to the book which is a classic written by Mark Twain. The movie starts of with a scene where Hank sings a song with a bunch of kids called "when you stub your toe on the moon" It reminds me of Sinatra's song High Hopes, it is fun and inspirational. The Music is great throughout and my favorite song is sung by the King, Hank (bing Crosby) and Sir "Saggy" Sagamore. OVerall a great family movie or even a great Date movie. This is a movie you can watch over and over again. The princess played by Rhonda Fleming is gorgeous. I love this movie!! If you liked Danny Kaye in the Court Jester then you will definitely like this movie.'
'>>> Label: 1'
但是我们要做完形填空,标签是没用的
2.数据集处理方法
这里文本长度要统一
- 计算每一个文本的长度(word_ids)
- 指定chunk_size,然后将所有数据按块进行拆分,比如每块128个,句子是700字节,要分成128,128,。。。。这种
先定义个函数,这样下面使用可以直接调用
def tokenize_function(examples):# 调用分词器(tokenizer)的函数,传入输入的文本数据集result = tokenizer(examples["text"])# 如果分词器支持快速模式,则生成单词索引(word_ids)if tokenizer.is_fast:result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]# 返回转换后的结果return result
进行文本处理
tokenized_datasets = imdb_dataset.map(tokenize_function, # 应用tokenize_function函数到每个样本batched=True, # 是否将数据分成批次处理remove_columns=["text", "label"] # 要从数据集中移除的列
)
首先,imdb_dataset.map() 方法被用来应用 tokenize_function 函数到 imdb_dataset 的 train 部分。这个方法会对数据集中的每个样本应用指定的函数,并返回一个新的数据集,其中包含应用函数后的结果。
batched=True 参数告诉 map 方法将输入数据分成批次进行处理。这通常是为了提高效率,尤其是在处理大型数据集时。
remove_columns=[“text”, “label”] 参数告诉 map 方法在处理数据时移除指定的列。在这个例子中,它移除了 text 和 label 列,因为 text 列已经被处理为模型的输入,而 label 列不再需要,因为咱们是完形填空任务,不需要标签。
然后进行切分
tokenizer.model_max_length
chunk_size = 128
tokenizer.model_max_length 是一个属性,它表示模型能够接受的最大输入长度。这个属性通常用于序列标注任务,以确保输入的长度不超过模型的最大接受长度。
chunk_size 是一个参数,用于指定将输入序列分割成小块的大小。这个参数通常用于处理过长的输入序列,以便将其分割成多个小块,然后分别处理这些小块。
因为上限是512 所以你切分要是64 128 256这种
切分的时候也可以先看下文本长度
# 看看每一个都多长
tokenized_samples = tokenized_datasets["train"][:3]for idx, sample in enumerate(tokenized_samples["input_ids"]):print(f"'>>> Review {idx} length: {len(sample)}'")
'>>> Review 0 length: 363'
'>>> Review 1 length: 304'
'>>> Review 2 length: 133'
这里看出128切分比较合适
先拿着三个试试
concatenated_examples = {k: sum(tokenized_samples[k], []) for k in tokenized_samples.keys()#计算拼一起有多少个,
}
total_length = len(concatenated_examples["input_ids"])
print(f"'>>> Concatenated reviews length: {total_length}'")
'>>> Concatenated reviews length: 800'
下面进行切分
chunks = {# 使用字典推导式(dict comprehension)创建一个新的字典,其中键是序列的键(k),值是分割后的块k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]# 遍历concatenated_examples字典中的每个键值对for k, t in concatenated_examples.items()
}for chunk in chunks["input_ids"]:# 打印每个chunk的长度print(f"'>>> Chunk length: {len(chunk)}'")
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 32'
最后那个不够数,直接给他丢弃
def group_texts(examples):# 将所有的文本实例拼接到一起concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}# 计算拼接后的总长度total_length = len(concatenated_examples[list(examples.keys())[0]])# 使用整除运算符(//)计算每个chunk的长度,然后乘以chunk_size,以确保不会出现多余的文本total_length = (total_length // chunk_size) * chunk_size# 根据计算出的总长度,对拼接后的文本进行切分result = {# 使用字典推导式(dict comprehension)创建一个新的字典,其中键是原始字典的键(k),值是切分后的块k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]# 遍历concatenated_examples字典中的每个键值对for k, t in concatenated_examples.items()}# 如果完型填空任务需要使用标签,则将标签复制到结果字典中result["labels"] = result["input_ids"].copy()# 返回分割后的结果return result
使用整除运算符(//)计算每个chunk的长度,然后乘以chunk_size这个就是说如果小于128 整除后就得0,就没了
train: Dataset({features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],num_rows: 61291})test: Dataset({features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],num_rows: 59904})unsupervised: Dataset({features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],num_rows: 122957})
})
会发现数据量大了
3.完形填空训练
from transformers import DataCollatorForLanguageModeling # 从transformers库中导入DataCollatorForLanguageModeling类data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15) # 创建一个数据收集器
然后再举个例子看看mask长啥样
samples = [lm_datasets["train"][i] for i in range(2)] # 从训练数据中选择2个样本
for sample in samples:_ = sample.pop("word_ids") # 移除样本中的"word_ids"键,因为发现没啥用print(sample)for chunk in data_collator(samples)["input_ids"]:print(f"\n'>>> {tokenizer.decode(chunk)}'")#tokenizer.decode(chunk)是为了让被隐藏的更明显点,所以直接decode出来print(len(chunk))
训练过程
from transformers import TrainingArguments # 从transformers库中导入TrainingArguments类batch_size = 64 # 定义训练和评估时的批次大小
# 计算每个epoch打印结果的步数
logging_steps = len(downsampled_dataset["train"]) // batch_size
model_name = model_checkpoint.split("/")[-1] # 从模型检查点路径中提取模型名称training_args = TrainingArguments(output_dir=f"{model_name}-finetuned-imdb", # 指定输出目录,其中包含微调后的模型和日志文件overwrite_output_dir=True, # 是否覆盖现有的输出目录evaluation_strategy="epoch", # 指定评估策略,这里为每个epoch评估一次learning_rate=2e-5, # 学习率weight_decay=0.01, # 权重衰减系数per_device_train_batch_size=batch_size, # 每个GPU的训练批次大小per_device_eval_batch_size=batch_size, # 每个GPU的评估批次大小logging_steps=logging_steps, # 指定每个epoch打印结果的步数num_train_epochs=1, # 指定训练的epoch数量save_strategy='epoch', # 指定保存策略,这里为每个epoch保存一次模型
)
生成的模型
from transformers import Trainer # 从transformers库中导入Trainer类trainer = Trainer(model=model, # 指定要训练的模型实例args=training_args, # 指定训练参数对象train_dataset=downsampled_dataset["train"], # 指定训练数据集eval_dataset=downsampled_dataset["test"], # 指定评估数据集data_collator=data_collator, # 指定数据收集器
)
评估标准使用困惑度
import math # 导入math模块,用于计算对数和指数eval_results = trainer.evaluate() # 使用Trainer的evaluate方法评估模型print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}") # 打印评估结果中的对数
人话就是你不得在mask那挑啥词合适吗,平均挑了多少个才能答对
训练模型
trainer.train()
eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
发现困惑度降低了
from transformers import AutoModelForMaskedLMmodel_checkpoint = "./distilbert-base-uncased"
model = AutoModelForMaskedLM.from_pretrained("./distilbert-base-uncased-finetuned-imdb/checkpoint-157")
加载自己的模型
import torch # 导入torch库,用于处理张量
使用分词器将文本转换为模型的输入格式
参数 return_tensors=“pt” 表示返回PyTorch张量格式
inputs = tokenizer(text, return_tensors=“pt”)
执行模型预测
# model 是一个预训练的BERT模型实例
# **inputs 表示将inputs字典中的所有键值对作为关键字参数传递给model
token_logits = model(**inputs).logits# 找到遮蔽词在输入中的索引
# inputs["input_ids"] 是模型输入的词汇索引张量
# tokenizer.mask_token_id 是遮蔽词的词汇索引
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]# 获取遮蔽词的预测logits
# mask_token_index 是遮蔽词在输入中的索引张量
# token_logits 是模型输出的预测logits张量
mask_token_logits = token_logits[0, mask_token_index, :]# 找到前5个最可能的替换词
# torch.topk 函数用于找到最大k个值及其索引
# dim=1 表示在第二个维度(即词汇维度)上进行排序
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()# 打印每个最可能的替换词及其替换后的文本
for token in top_5_tokens:# 使用 tokenizer.decode 方法将索引转换为文本# text 是原始文本# tokenizer.mask_token 是遮蔽词# [token] 是替换词的索引列表print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")
'>>> This is a great deal.'
'>>> This is a great idea.'
'>>> This is a great adventure.'
'>>> This is a great film.'
'>>> This is a great movie.'
可以看到结果和上面的通用的不一样了,这里是film movie这些针对你的训练数据的了--------
相关文章:

深度学习笔记(8)预训练模型
深度学习笔记(8)预训练模型 文章目录 深度学习笔记(8)预训练模型一、预训练模型构建一、微调模型,训练自己的数据1.导入数据集2.数据集处理方法3.完形填空训练 使用分词器将文本转换为模型的输入格式参数 return_tenso…...
C#事件的用法
前言 在C#中,事件(Event)可以实现当类内部发生某些特定的事情时,它可以通知其他类或对象。事件是基于委托(Delegate)的,委托是一种类型安全的函数指针,它定义了方法的类型ÿ…...

金砖软件测试赛项之Jmeter如何录制脚本!
一、简介 Apache JMeter 是一款开源的性能测试工具,用于测试各种服务的负载能力,包括Web应用、数据库、FTP服务器等。它可以模拟多种用户行为,生成负载以评估系统的性能和稳定性。 JMeter 的主要特点: 图形用户界面:…...
docker-squash镜像压缩
docker-squash 和 docker export docker load 的原理和效果有一些相似之处,但它们的工作方式和适用场景有所不同。 docker-squash docker-squash 是一个工具,它通过分析 Docker 镜像的层(layers)并将其压缩成更少的层来减小镜像…...
Vue3快速入门+axios的异步请求(基础使用)
学习Vue之前先要学习htmlcssjs的基础使用 Vue其实是js的框架 常用到的Vue指令包括vue-on,vue-for,vue-blind,vue-if&vue-show,v-modul vue的基础模板: <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8&…...

VM16安装macOS11
注意: 本文内容于 2024-09-17 12:08:24 创建,可能不会在此平台上进行更新。如果您希望查看最新版本或更多相关内容,请访问原文地址:VM16安装macOS11。感谢您的关注与支持! 使用 Vmware Workstation Pro 16 安装 macOS…...

自定义复杂AntV/G6案例
一、效果图 二、源码 /** * * Author: me * CreatDate: 2024-08-22 * * Description: 复杂G6案例 * */ <template><div class"moreG6-wapper"><div id"graphContainer" ref"graphRef" class"graph-content"></d…...

Golang | Leetcode Golang题解之第419题棋盘上的战舰
题目: 题解: func countBattleships(board [][]byte) (ans int) {for i, row : range board {for j, ch : range row {if ch X && !(i > 0 && board[i-1][j] X || j > 0 && board[i][j-1] X) {ans}}}return }...
CCF刷题计划——LDAP(交集、并集 how to go)
LDAP 计算机软件能力认证考试系统 不知道为什么,直接给我报一个运行错误,得了0分。但是我在Dev里,VS里面都跑的好好的,奇奇怪怪。如果有大佬路过,请帮小弟看看QWQ。本题学到的:交集set_intersection、并集…...

谷歌论文提前揭示o1模型原理:AI大模型竞争或转向硬件
Open AI最强模型o1的护城河已经没有了?仅在OpenAI发布最新推理模型o1几日之后,海外社交平台 Reddit 上有网友发帖称谷歌Deepmind在 8 月发表的一篇论文内容与o1模型原理几乎一致,OpenAI的护城河不复存在。 谷歌DeepMind团队于今年8月6日发布…...

【ShuQiHere】 探索数据挖掘的世界:从概念到应用
🌐 【ShuQiHere】 数据挖掘(Data Mining, DM) 是一种从大型数据集中提取有用信息的技术,无论是在商业分析、金融预测,还是医学研究中,数据挖掘都扮演着至关重要的角色。本文将带您深入了解数据挖掘的核心概…...

LabVIEW提高开发效率技巧----使用事件结构优化用户界面响应
事件结构(Event Structure) 是 LabVIEW 中用于处理用户界面事件的强大工具。通过事件驱动的编程方式,程序可以在用户操作时动态执行特定代码,而不是通过轮询(Polling)的方式不断检查界面控件状态。这种方式…...
【前端】ES6:Set与Map
文章目录 1 Set结构1.1 初识Set1.2 实例的属性和方法1.3 遍历1.4 复杂数据结构去重 2 Map结构2.1 初识Map2.2 实例的属性和方法2.3 遍历 1 Set结构 它类似于数组,但成员的值都是唯一的,没有重复的值。 1.1 初识Set let s1 new Set([1, 2, 3, 2, 3]) …...
Java 之网络编程小案例
1. 多发多收 描述: 编写一个简单的聊天程序,客户端可以向服务器发送多条消息,服务器可以接收所有消息并回复。 代码示例: 服务器端 (Server.java): import java.io.*; import java.net.*; import java.util.concurrent.Execut…...
Spring Boot:现代化Java应用开发的艺术
目录 什么是Spring Boot? 为什么选择Spring Boot? Spring Boot的核心概念 详细步骤:创建一个Spring Boot应用 步骤1:使用Spring Initializr创建项目 步骤2:解压并导入项目 步骤3:构建和配置项目 po…...
Redis五种基本数据结构的使用
Redis具有五种基本数据类型:String(字符串)、Hash(哈希)、List(列表)、Set(集合)、SortedSet(有序集合),下面示意它们的使用。 String类数据类型的使用 增:添加数据(set)、添加多个数据(mset)、添加数据时指定过期时间(setex) 删…...

【QT】系统-下
欢迎来到Cefler的博客😁 🕌博客主页:折纸花满衣 🏠个人专栏:QT 目录 👉🏻QTheadrun() 👉🏻QMutex👉🏻QWaitCondition👉🏻Q…...
java和kotlin 可以同时运行吗
Java 和 Kotlin 可以同时运行在同一个项目中,这主要得益于 Kotlin 对 Java 的互操作性。Kotlin 被设计为与 Java 100% 兼容,这意味着 Kotlin 代码可以很容易地调用 Java 代码,反之亦然。这种设计使得 Kotlin 能够无缝集成到现有的 Java 项目中…...

2024最新版 Tuxera NTFS for Mac 2023绿色版图文安装教程
在数字化时代,数据的存储和传输变得至关重要。Mac用户经常需要在Windows NTFS格式的移动硬盘上进行读写操作,然而,由于MacOS系统默认不支持NTFS的写操作,这就需要我们寻找一款高效的读写软件。Tuxera NTFS for Mac 2023便是其中…...
npm发布插件超级简单版
在开源的世界里,每个人都有机会成为贡献者,甚至是创新的引领者。您是否有过这样的想法:开发一个解决特定问题的小工具,让他成为其他开发者手中的利器?今天,我们就来一场实战训练,学习如何将你的…...
Java 语言特性(面试系列1)
一、面向对象编程 1. 封装(Encapsulation) 定义:将数据(属性)和操作数据的方法绑定在一起,通过访问控制符(private、protected、public)隐藏内部实现细节。示例: public …...

汽车生产虚拟实训中的技能提升与生产优化
在制造业蓬勃发展的大背景下,虚拟教学实训宛如一颗璀璨的新星,正发挥着不可或缺且日益凸显的关键作用,源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例,汽车生产线上各类…...

Linux-07 ubuntu 的 chrome 启动不了
文章目录 问题原因解决步骤一、卸载旧版chrome二、重新安装chorme三、启动不了,报错如下四、启动不了,解决如下 总结 问题原因 在应用中可以看到chrome,但是打不开(说明:原来的ubuntu系统出问题了,这个是备用的硬盘&a…...
拉力测试cuda pytorch 把 4070显卡拉满
import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试,通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小,增大可提高计算复杂度duration: 测试持续时间(秒&…...
【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具
第2章 虚拟机性能监控,故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令:jps [options] [hostid] 功能:本地虚拟机进程显示进程ID(与ps相同),可同时显示主类&#x…...
IP如何挑?2025年海外专线IP如何购买?
你花了时间和预算买了IP,结果IP质量不佳,项目效率低下不说,还可能带来莫名的网络问题,是不是太闹心了?尤其是在面对海外专线IP时,到底怎么才能买到适合自己的呢?所以,挑IP绝对是个技…...

uniapp 开发ios, xcode 提交app store connect 和 testflight内测
uniapp 中配置 配置manifest 文档:manifest.json 应用配置 | uni-app官网 hbuilderx中本地打包 下载IOS最新SDK 开发环境 | uni小程序SDK hbulderx 版本号:4.66 对应的sdk版本 4.66 两者必须一致 本地打包的资源导入到SDK 导入资源 | uni小程序SDK …...
【Kafka】Kafka从入门到实战:构建高吞吐量分布式消息系统
Kafka从入门到实战:构建高吞吐量分布式消息系统 一、Kafka概述 Apache Kafka是一个分布式流处理平台,最初由LinkedIn开发,后成为Apache顶级项目。它被设计用于高吞吐量、低延迟的消息处理,能够处理来自多个生产者的海量数据,并将这些数据实时传递给消费者。 Kafka核心特…...
day51 python CBAM注意力
目录 一、CBAM 模块简介 二、CBAM 模块的实现 (一)通道注意力模块 (二)空间注意力模块 (三)CBAM 模块的组合 三、CBAM 模块的特性 四、CBAM 模块在 CNN 中的应用 一、CBAM 模块简介 在之前的探索中…...
组合模式:构建树形结构的艺术
引言:处理复杂对象结构的挑战 在软件开发中,我们常遇到需要处理部分-整体层次结构的场景: 文件系统中的文件与文件夹GUI中的容器与组件组织结构中的部门与员工菜单系统中的子菜单与菜单项组合模式正是为解决这类问题而生的设计模式。它允许我们将对象组合成树形结构来表示&…...