当前位置: 首页 > news >正文

【NLP相关】基于现有的预训练模型使用领域语料二次预训练


❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

BERT

【NLP相关】基于现有的预训练模型使用领域语料二次预训练

在自然语言处理领域,预训练模型已经成为了最为热门和有效的技术之一。预训练模型通过在大规模文本语料库上进行训练,可以学习到通用的语言模型,然后可以在不同的任务上进行微调。但是,预训练模型在领域特定任务上的表现可能不够好,因为它们是在通用语言语料库上进行训练的。为了提高在特定领域的任务中的性能,我们可以使用领域语料库对预训练模型进行二次预训练。

本篇博客将介绍如何基于现有的预训练模型使用领域语料二次预训练。我们将以 PyTorch 和 Transformers 库为基础,以医学文本分类任务为例,来详细说明二次预训练的过程。

1. 模型介绍

在本篇博客中,我们使用的预训练模型是 BERT(Bidirectional Encoder Representations from Transformers)。BERT 是一种基于 Transformer 的预训练模型,由 Google 团队开发。它在多个自然语言处理任务上取得了最先进的结果,例如文本分类、命名实体识别和问答系统等。

BERT 模型是一种双向的 Transformer 模型,能够有效地处理自然语言序列。它将文本输入嵌入到向量空间中,并在此基础上进行自监督训练,以学习通用的语言表示。在预训练完成后,BERT 模型可以进行微调,以适应不同的自然语言处理任务。

2. 代码实现

2.1 数据预处理

在开始二次预训练之前,我们需要准备领域特定的语料库。在这里,我们使用的是医学文本分类数据集,其中包含了一些医学文章的标题和摘要,并且每个文本都被标记为一个预定义的类别。

首先,我们需要将原始文本数据拆分为单个句子,并将其标记化处理。我们可以使用 Hugging Face 的 tokenizer 来完成这个任务。

from transformers import BertTokenizer# 加载 BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")class MedicalDataset(Dataset):def __init__(self, tokens, max_length=128):self.tokens = tokensself.max_length = max_lengthdef __len__(self):return len(self.tokens)def __getitem__(self, idx):# 获取句子对tokens = self.tokens[idx]# 将句子对拼接成一个序列,并将其标记化处理input_ids = tokenizer.encode(tokens[0], tokens[1],add_special_tokens=True, max_length=self.max_length,truncation_strategy='longest_first')attention_mask = [1] * len(input_ids)# 填充序列长度padding = [0] * (self.max_length - len(input_ids))input_ids += paddingattention_mask += padding# 返回 input_ids 和 attention_maskreturn torch.LongTensor(input_ids), torch.LongTensor(attention_mask)

2.2 二次预训练

在数据预处理之后,我们可以开始进行二次预训练了。在这里,我们将使用 Hugging Face 的 Transformers 库,以及 PyTorch 框架来实现二次预训练。

首先,我们需要加载预训练的 BERT 模型。在这里,我们使用的是 bert-base-uncased 模型,它是一个基于英文的预训练模型。我们还需要定义一些训练参数,例如学习率和批大小等。

from transformers import BertForPreTraining, AdamW
from torch.utils.data import DataLoader# 加载预训练的 BERT 模型
model = BertForPreTraining.from_pretrained('bert-base-uncased')# 定义训练参数
epochs = 3
batch_size = 16
learning_rate = 2e-5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 将模型移动到 GPU 上
model.to(device)

接下来,我们需要加载领域特定的语料库,并将其转换为 PyTorch 数据集。在这里,我们使用的是 PyTorch 中的 Dataset 类。我们还需要将数据集加载到 PyTorch 的数据加载器中,以便进行训练。

# 加载领域特定的语料库
with open('medical_data.txt') as f:sentences = f.readlines()# 将语料库转换为 PyTorch 数据集
dataset = MedicalDataset(sentences)# 将数据集加载到 PyTorch 的数据加载器中
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

在准备好数据集之后,我们可以开始训练模型了。我们将使用 AdamW 优化器和交叉熵损失函数来训练模型。在每个 epoch 完成之后,我们会对模型进行一次测试,并计算准确率和损失函数值。最后,我们将保存训练好的模型。

# 定义优化器和损失函数
optimizer = AdamW(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()# 训练模型
for epoch in range(epochs):total_loss = 0total_correct = 0total_samples = 0# 遍历数据集for i, batch in enumerate(loader):# 将输入数据和标签移动到 GPU 上input_ids, attention_mask = batchinput_ids = input_ids.to(device)attention_mask = attention_mask.to(device)# 将模型设置为训练模式model.train()# 计算模型的输出outputs = model(input_ids, attention_mask=attention_mask)# 计算损失函数值loss = criterion(outputs.logits.view(-1, 2), outputs.labels.view(-1))# 清除之前的梯度optimizer.zero_grad()# 反向传播和优化loss.backward()optimizer.step()# 统计训练信息total_loss += loss.item()total_samples += input_ids.size(0)total_correct += torch.sum(torch.argmax(outputs.logits, dim=-1) == outputs.labels.view(-1)).item()# 输出训练信息if (i + 1) % 100 == 0:print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Accuracy: %.2f%%'% (epoch + 1, epochs, i + 1, len(loader),total_loss / total_samples, total_correct / total_samples * 100))# 在每个 epoch 完成之后进行一次测试
with torch.no_grad():total_loss = 0total_correct = 0total_samples = 0# 将模型设置为评估模式model.eval()# 遍历测试数据集for i, batch in enumerate(test_loader):# 将输入数据和标签移动到 GPU 上input_ids, attention_mask = batchinput_ids = input_ids.to(device)attention_mask = attention_mask.to(device)# 计算模型的输出outputs = model(input_ids, attention_mask=attention_mask)# 计算损失函数值loss = criterion(outputs.logits.view(-1, 2), outputs.labels.view(-1))# 统计测试信息total_loss += loss.item()total_samples += input_ids.size(0)total_correct += torch.sum(torch.argmax(outputs.logits, dim=-1) == outputs.labels.view(-1)).item()# 输出测试信息print('Epoch [%d/%d], Test Loss: %.4f, Test Accuracy: %.2f%%'% (epoch + 1, epochs, total_loss / total_samples, total_correct / total_samples * 100))#保存训练好的模型torch.save(model.state_dict(), 'medical_bert.pth')

3. 案例解析

假设我们要对医学领域中的文本进行二次预训练。我们可以使用已经预训练好的 BERT 模型,并使用医学领域的语料库进行二次预训练。

首先,我们需要将医学领域的语料库进行预处理。在这里,我们可以使用 NLTK 库来进行分词和词形还原等操作。我们还可以将语料库中的每个句子转换为 BERT 输入格式。

import nltk
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from transformers import BertTokenizer# 加载 BERT 分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')# 加载 NLTK 分词器和词形还原
nltk.download('punkt')
nltk.download('wordnet')
lemmatizer = WordNetLemmatizer()加载医学领域的语料库
with open('medical_corpus.txt', 'r') as f:corpus = f.read()#对每个句子进行分词、词形还原和转换为 BERT 输入格式
sentences = []
for sentence in nltk.sent_tokenize(corpus):words = nltk.word_tokenize(sentence)words = [lemmatizer.lemmatize(word) for word in words]words = [word.lower() for word in words]tokens = tokenizer.encode_plus(words,add_special_tokens=True,max_length=512,padding='max_length',truncation=True)sentences.append((tokens['input_ids'], tokens['attention_mask']))

接下来,我们需要使用这些句子对 BERT 模型进行二次预训练。为此,我们需要定义一个新的数据加载器,将这些句子传递给模型进行训练。

from torch.utils.data import Dataset, DataLoaderclass MedicalDataset(Dataset):def __init__(self, sentences):self.sentences = sentencesdef __len__(self):return len(self.sentences)def __getitem__(self, idx):return self.sentences[idx]# 定义数据加载器
loader = DataLoader(MedicalDataset(sentences),batch_size=16,shuffle=True)

现在,我们可以开始对 BERT 模型进行二次预训练了。我们可以使用与之前相同的训练代码。

# 定义训练函数
def train(model, loader, optimizer, device):model.train()for batch in loader:input_ids = batch[0].to(device)attention_mask = batch[1].to(device)optimizer.zero_grad()loss, _ = model(input_ids=input_ids,attention_mask=attention_mask,output_hidden_states=True)[:2]loss.backward()optimizer.step()# 加载预训练的 BERT 模型
model = BertForMaskedLM.from_pretrained('bert-base-uncased')# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)# 定义优化器
optimizer = AdamW(model.parameters(), lr=5e-5)# 进行二次预训练
for epoch in range(num_epochs):train(model, train_loader, optimizer, device)# 每个 epoch 结束后测试模型的性能perplexity = evaluate(model, test_loader, device)print(f'Epoch {epoch+1}, perplexity: {perplexity:.3f}')# 保存模型model_path = f'model_epoch{epoch+1}.pt'torch.save(model.state_dict(), model_path)

这里定义了一个 train 函数来训练模型。这个函数接收一个模型、一个数据加载器、一个优化器和一个设备作为输入。它会将模型设为训练模式,并且在每个批次上运行前向传播、计算损失、计算梯度和更新参数。

接下来,我们加载预训练的 BERT 模型,并将其移动到所选设备上。我们使用 AdamW 优化器,并将学习率设置为 5e-5。

最后,我们使用一个简单的 for 循环来进行二次预训练。在每个 epoch 结束时,我们会在测试集上评估模型,并打印出 perplexity 指标。我们还会将模型保存在磁盘上,以便以后进行检索。


❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

相关文章:

【NLP相关】基于现有的预训练模型使用领域语料二次预训练

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博…...

使用git进行项目管理--git使用及其常用命令

使用git进行项目管理 文章目录 使用git进行项目管理git使用1.添加用户名字2.添加用户邮箱3.git初始化4.add5.commit6.添加到gitee仓库7.推送到gitee8.切换版本git常用命令git add把指定的文件添加到暂存区中添加所有修改、已删除的文件到暂存区中添加所有修改、已删除、新增的文…...

Mybatis_CRUD使用

目录1 Mybatis简介环境说明:预备知识:1.1 定义1.2 持久化为什么需要持久化服务呢?1.3 持久层1.4 为什么需要Mybatis2 依赖配置3 CRUDnamespaceselect (查询用户数据)※传值方式:于方法中传值使用Map传值insert (插入用…...

JVM的过程内分析和过程间分析有什么区别?

问: 目前所有常见的Java虚拟机对过程间分析的支持都相 当有限,要么借助大规模的方法内联来打通方法间的隔阂,以过程内分析(Intra-Procedural Analysis, 只考虑过程内部语句,不考虑过程调用的分析&#xff…...

LearnDash测验报告如何帮助改进您的课程

某一个场景。Pennywell 大学有一门课程“Introduction to Linear Algebra”。上学期进行了两次测验。20% 的学生在第一次测验中不及格,而 80% 在第二次测验中不及格。在进一步评估中,观察到第一次测验不及格的学生在第二次测验中也不及格。在第二次测验中…...

如何通过Java将Word转换为PDF

Word是我们日常编辑文档内容时十分常用的一种文档格式。但相比之下,PDF文档的格式、布局更为固定,不易被更改。在保存或传输较为重要的文档内容时,PDF文档格式也时很多人的不二选择。很多时候我们都会遇到需要将Word转换为PDF的情况。下面我就…...

DOM型XSS

DOM型XSSDOM是什么DOM型XSSDOM型XSS实操DOM是什么 DOM就是Document。 文档是由节点构成的集合&#xff0c;在DOM里存在许多不同类型的节点&#xff0c;主要有&#xff1a;元素节点、文本节点&#xff0c;属性节点。 元素节点&#xff1a;好比< body >< p >< h …...

04-项目立项:项目方案、可行性分析、产品规划、立项评审

文章目录4.1 项目方案立项阶段4.2 可行性分析4.3 产品规划4.4 立项评审4.4.1 立项说明书的主要内容4.4.2 立项评审流程章节总结4.1 项目方案 学习目标&#xff1a; 能够输出产品项目方案 项目开发设计流程的主要阶段&#xff1a; 立项阶段 → 设计阶段 → 开发阶段 → 测试阶…...

数据分享|NPP VIIRS夜间灯光数据(2012-2020逐月)

2011年10月美国的“索米”国家极轨卫星伙伴卫星(Suomi National Polar-orbiting Partnership or Suomi NPP)发射,它搭载的VIIRS传感器上有一个称为DNB(Day Night Band)的波段能够在500米分辨率(比原来的OLS提高6倍)的尺度上对地表开展每天覆盖全球一次的高灵敏度(比OLS提…...

网络概论笔记

概论 网络研究的是节点和边 移动互联到物联网时代&#xff0c;只有有互联网&#xff0c;网络就不会落伍 协议&#xff1a;对等层面的实体固定的通信规则 协议包括&#xff1a;语法&#xff0c;语义&#xff0c;格式&#xff0c;次序&#xff0c;动作 网络是任意连接的 服务…...

软工2023个人作业二——软件案例分析

项目内容这个作业属于哪个课程2023年北航敏捷软件工程这个作业的要求在哪里个人作业-软件案例分析我在这个课程的目标是学习并掌握现代软件开发和项目管理技术&#xff0c;体验敏捷开发工作流程这个作业在哪个具体方面帮助我实现目标从软件工程角度分析比较我们所熟悉的软件&am…...

python数据分析表格文档Excel数据分析器统计源码

wx供重浩&#xff1a;创享软件 对话框发送&#xff1a;python表格 获取完整源码源文件说明文档可执行文件等 在PyCharm中运行《Excel数据分析师》即可进入如图1所示的系统主界面。在该界面中&#xff0c;通过顶部的工具栏可以选择所要进行的操作。 具体的操作步骤如下&#xff…...

Istio Sidecar启动顺序 - 导致的应用容器网络不通

目录一、问题二、Istio 1.7及其之后版本的解决方案2.1 方式1&#xff1a;安装Istio时全局设置2.2 方式2&#xff1a;在应用Deployment通过annotation设置2.3 holdApplicationUntilProxyStarts启用效果三、Istio 1.7之前的解决方案一、问题 线上应用集成了Spring Cloud K8S Con…...

3696. 构造有向无环图

Powered by:NEFU AB-IN Link 文章目录3696. 构造有向无环图题意思路代码3696. 构造有向无环图 题意 Codeforces Round 656 (Div. 3) E 给定一个由 n个点和 m条边构成的图。 不保证给定的图是连通的。 图中的一部分边的方向已经确定&#xff0c;你不能改变它们的方向。 剩下的边…...

RuoYi-Flowable-Plus(代码生成)

RuoYi-Flowable-Plus搭建 若依所有扩展项目的代码生成功能都是一样的&#xff0c;RuoYi-Flowable-Plus为例来演示。 模块创建 1.创建新模块ruoyi-student2.编辑RuoYi-Flowable-Plus\pom.xml <dependency><groupId>com.ruoyi</groupId><artifactId>ruoy…...

训练CV模型常用的方法与技巧

最近参加一个CV比赛&#xff0c;看到有参赛者分享了自己训练图像识别模型时常用到的小技巧&#xff0c;故对其进行记录、整理&#xff0c;方便未来继续学习。整理了很多&#xff0c;它们不一定每次有用&#xff0c;但请记在心中&#xff0c;说不定未来某个任务它们就发挥了作用…...

[Java·算法·中等]LeetCode22. 括号生成

每天一题&#xff0c;防止痴呆题目示例分析思路1题解1分析思路2题解2分析思路3题解3&#x1f449;️ 力扣原文 题目 数字 n 代表生成括号的对数&#xff0c;请你设计一个函数&#xff0c;用于能够生成所有可能的并且 有效的 括号组合。 示例 输入&#xff1a;n 3 输出&…...

Git项目合并实践

Git项目合并实践 一、前言 环境 操作系统&#xff1a;Windows 10 专业版 代码托管平台&#xff1a;Gitee 场景 同一个项目&#xff0c;在某一个时间点&#xff0c;被另外一个团队拷贝和修改&#xff0c;并且代码不在同一个仓库&#xff0c;最后需要合并项目 不是同一个项…...

C++实战md5、base64算法实现(附源码)

C++常用功能源码系列 文章目录 C++常用功能源码系列前言一、常用加密算法1. md5是什么二、源码1. md52. base64、decode总结前言 本文是C/C++常用功能代码封装专栏的导航贴。部分来源于实战项目中的部分功能提炼,希望能够达到你在自己的项目中拿来就用的效果,这样更好的服务…...

P6专题:P6 EPPM和PPM基本概念

目录 引言 Oracles Primavera P6 Enterprise Project Portfolio Management&#xff08;P6 EPPM&#xff09; Oracles Primavera P6 Professional Project Management 引言 Oracle Primavera系列软件专注于项目密集型企业&#xff0c;其整个项目生命周期内所有项目的组合管…...

【为什么事务@Transactional会失效】

在Spring框架中&#xff0c;Transactional注解用于声明一个方法需要被包含在事务中&#xff0c;以确保数据库操作的一致性和完整性。Transactional注解通常用于Service层或DAO层的方法上。 Transactional注解失效可能是由以下原因引起的&#xff1a; 注解未被正确声明或配置&a…...

NLP中的对话机器人——模型的评估

引言 本文是七月在线《NLP中的对话机器人》的视频笔记&#xff0c;主要介绍FAQ问答型聊天机器人的实现。 模型的评估 我们如何评估模型的好坏&#xff1f;由于我们的数据集没有提供测试数据&#xff0c;所以我们很难评估模型的好 坏。如果我们要做非常严谨的评估&#xff0c…...

数据挖掘知识规整与心得体会

一.大数据的特点&#xff1a; 数据多&#xff0c;类型多&#xff0c;更新快&#xff0c;更新内容多。 二.分类&#xff08;classification&#xff09;与混淆矩阵&#xff08;confusion matrix&#xff09; 这里的分类说的是二分类问题&#xff0c;比如说把人分为好人和坏人&a…...

修正一些formdesigner的一些bug与操作

之前集成了formdesigner表单设计器&#xff0c;但还是有些问题&#xff0c;所以进行一些bug修复与功能修正 一、之前组件布局的图标不见了 在main.js里增加下面一行 import /components/formdesigner/assets/iconfont/iconfont.js 效果如下&#xff1a; 二、选择列表没有数…...

前端网络安全

什么是同源策略同源指的是&#xff1a;协议、端口号、域名必须一致。他是浏览器的一个用于隔离潜在恶意文件的重要安全机制。限制了从同一个源加载的文档或脚本&#xff0c;与另一个源的资源进行交互。同源策略主要限制了三个方面&#xff1a;当前域下的js脚本不能够访问其他域…...

docker内存统计

在docker里top和在docker外top看内存都是没有变化的&#xff0c;但是用docker stats看mem uasge就一直在涨top命令和docker stats命令采集内存使用的方式不同所致。top命令采集的是当前进程的内存使用情况&#xff0c;而docker stats命令采集的是整个Docker容器的内存使用情况。…...

【IDEA】IDEA使用有道翻译引擎—详细配置步骤

目录 前言 步骤一&#xff1a;下载翻译工具Translate 步骤二&#xff1a;注册登录有道云平台 步骤三&#xff1a;配置有道翻译 前言 2022年10月 谷歌翻译已经不在中国了&#xff0c;所以IDEA配置谷歌翻译会出错。 步骤一&#xff1a;下载翻译工具Translate 打开idea设置set…...

js求解《初级算法》56.最长公共前缀

一、题目描述 编写一个函数来查找字符串数组中的最长公共前缀。如果不存在公共前缀&#xff0c;返回空字符串 ""。 输入&#xff1a;strs ["flower","flow","flight"] 输出&#xff1a;"fl" 输入&#xff1a;strs ["…...

嵌入式Linux(二十四)系统烧写

将uboot&#xff0c;linux kernel&#xff0c;.dtb&#xff0c;rootfs烧写到板子上的EMMC上&#xff0c;避免断网导致不能运行。 1. MfgTool工具介绍 一路解压之后&#xff0c;得到以下两项&#xff1a; ①Profiles文件夹&#xff1a;后续烧写文件放到这个文件夹。  其中关注…...

【ECNU】3496. 贪吃的 xjj 和贪心的 oxx(C++)

目录 题目 输入格式 输出格式 样例 提示 思路 代码 题目 单点时限: 2.0 sec 内存限制: 256 MB oxx 与 xjj 终于到了 Xiamen&#xff0c;他们第一件事就是去吃当地著名的特产椰子饼。 他们共买了 n 盒礼盒&#xff0c;第 i 盒含 ai 块椰子饼。oxx 与 xjj 约定让 oxx …...