【NLP相关】基于现有的预训练模型使用领域语料二次预训练

【NLP相关】基于现有的预训练模型使用领域语料二次预训练
在自然语言处理领域,预训练模型已经成为了最为热门和有效的技术之一。预训练模型通过在大规模文本语料库上进行训练,可以学习到通用的语言模型,然后可以在不同的任务上进行微调。但是,预训练模型在领域特定任务上的表现可能不够好,因为它们是在通用语言语料库上进行训练的。为了提高在特定领域的任务中的性能,我们可以使用领域语料库对预训练模型进行二次预训练。
本篇博客将介绍如何基于现有的预训练模型使用领域语料二次预训练。我们将以 PyTorch 和 Transformers 库为基础,以医学文本分类任务为例,来详细说明二次预训练的过程。
1. 模型介绍
在本篇博客中,我们使用的预训练模型是 BERT(Bidirectional Encoder Representations from Transformers)。BERT 是一种基于 Transformer 的预训练模型,由 Google 团队开发。它在多个自然语言处理任务上取得了最先进的结果,例如文本分类、命名实体识别和问答系统等。
BERT 模型是一种双向的 Transformer 模型,能够有效地处理自然语言序列。它将文本输入嵌入到向量空间中,并在此基础上进行自监督训练,以学习通用的语言表示。在预训练完成后,BERT 模型可以进行微调,以适应不同的自然语言处理任务。
2. 代码实现
2.1 数据预处理
在开始二次预训练之前,我们需要准备领域特定的语料库。在这里,我们使用的是医学文本分类数据集,其中包含了一些医学文章的标题和摘要,并且每个文本都被标记为一个预定义的类别。
首先,我们需要将原始文本数据拆分为单个句子,并将其标记化处理。我们可以使用 Hugging Face 的 tokenizer 来完成这个任务。
from transformers import BertTokenizer# 加载 BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")class MedicalDataset(Dataset):def __init__(self, tokens, max_length=128):self.tokens = tokensself.max_length = max_lengthdef __len__(self):return len(self.tokens)def __getitem__(self, idx):# 获取句子对tokens = self.tokens[idx]# 将句子对拼接成一个序列,并将其标记化处理input_ids = tokenizer.encode(tokens[0], tokens[1],add_special_tokens=True, max_length=self.max_length,truncation_strategy='longest_first')attention_mask = [1] * len(input_ids)# 填充序列长度padding = [0] * (self.max_length - len(input_ids))input_ids += paddingattention_mask += padding# 返回 input_ids 和 attention_maskreturn torch.LongTensor(input_ids), torch.LongTensor(attention_mask)
2.2 二次预训练
在数据预处理之后,我们可以开始进行二次预训练了。在这里,我们将使用 Hugging Face 的 Transformers 库,以及 PyTorch 框架来实现二次预训练。
首先,我们需要加载预训练的 BERT 模型。在这里,我们使用的是 bert-base-uncased 模型,它是一个基于英文的预训练模型。我们还需要定义一些训练参数,例如学习率和批大小等。
from transformers import BertForPreTraining, AdamW
from torch.utils.data import DataLoader# 加载预训练的 BERT 模型
model = BertForPreTraining.from_pretrained('bert-base-uncased')# 定义训练参数
epochs = 3
batch_size = 16
learning_rate = 2e-5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 将模型移动到 GPU 上
model.to(device)
接下来,我们需要加载领域特定的语料库,并将其转换为 PyTorch 数据集。在这里,我们使用的是 PyTorch 中的 Dataset 类。我们还需要将数据集加载到 PyTorch 的数据加载器中,以便进行训练。
# 加载领域特定的语料库
with open('medical_data.txt') as f:sentences = f.readlines()# 将语料库转换为 PyTorch 数据集
dataset = MedicalDataset(sentences)# 将数据集加载到 PyTorch 的数据加载器中
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
在准备好数据集之后,我们可以开始训练模型了。我们将使用 AdamW 优化器和交叉熵损失函数来训练模型。在每个 epoch 完成之后,我们会对模型进行一次测试,并计算准确率和损失函数值。最后,我们将保存训练好的模型。
# 定义优化器和损失函数
optimizer = AdamW(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()# 训练模型
for epoch in range(epochs):total_loss = 0total_correct = 0total_samples = 0# 遍历数据集for i, batch in enumerate(loader):# 将输入数据和标签移动到 GPU 上input_ids, attention_mask = batchinput_ids = input_ids.to(device)attention_mask = attention_mask.to(device)# 将模型设置为训练模式model.train()# 计算模型的输出outputs = model(input_ids, attention_mask=attention_mask)# 计算损失函数值loss = criterion(outputs.logits.view(-1, 2), outputs.labels.view(-1))# 清除之前的梯度optimizer.zero_grad()# 反向传播和优化loss.backward()optimizer.step()# 统计训练信息total_loss += loss.item()total_samples += input_ids.size(0)total_correct += torch.sum(torch.argmax(outputs.logits, dim=-1) == outputs.labels.view(-1)).item()# 输出训练信息if (i + 1) % 100 == 0:print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Accuracy: %.2f%%'% (epoch + 1, epochs, i + 1, len(loader),total_loss / total_samples, total_correct / total_samples * 100))# 在每个 epoch 完成之后进行一次测试
with torch.no_grad():total_loss = 0total_correct = 0total_samples = 0# 将模型设置为评估模式model.eval()# 遍历测试数据集for i, batch in enumerate(test_loader):# 将输入数据和标签移动到 GPU 上input_ids, attention_mask = batchinput_ids = input_ids.to(device)attention_mask = attention_mask.to(device)# 计算模型的输出outputs = model(input_ids, attention_mask=attention_mask)# 计算损失函数值loss = criterion(outputs.logits.view(-1, 2), outputs.labels.view(-1))# 统计测试信息total_loss += loss.item()total_samples += input_ids.size(0)total_correct += torch.sum(torch.argmax(outputs.logits, dim=-1) == outputs.labels.view(-1)).item()# 输出测试信息print('Epoch [%d/%d], Test Loss: %.4f, Test Accuracy: %.2f%%'% (epoch + 1, epochs, total_loss / total_samples, total_correct / total_samples * 100))#保存训练好的模型torch.save(model.state_dict(), 'medical_bert.pth')
3. 案例解析
假设我们要对医学领域中的文本进行二次预训练。我们可以使用已经预训练好的 BERT 模型,并使用医学领域的语料库进行二次预训练。
首先,我们需要将医学领域的语料库进行预处理。在这里,我们可以使用 NLTK 库来进行分词和词形还原等操作。我们还可以将语料库中的每个句子转换为 BERT 输入格式。
import nltk
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from transformers import BertTokenizer# 加载 BERT 分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')# 加载 NLTK 分词器和词形还原
nltk.download('punkt')
nltk.download('wordnet')
lemmatizer = WordNetLemmatizer()加载医学领域的语料库
with open('medical_corpus.txt', 'r') as f:corpus = f.read()#对每个句子进行分词、词形还原和转换为 BERT 输入格式
sentences = []
for sentence in nltk.sent_tokenize(corpus):words = nltk.word_tokenize(sentence)words = [lemmatizer.lemmatize(word) for word in words]words = [word.lower() for word in words]tokens = tokenizer.encode_plus(words,add_special_tokens=True,max_length=512,padding='max_length',truncation=True)sentences.append((tokens['input_ids'], tokens['attention_mask']))
接下来,我们需要使用这些句子对 BERT 模型进行二次预训练。为此,我们需要定义一个新的数据加载器,将这些句子传递给模型进行训练。
from torch.utils.data import Dataset, DataLoaderclass MedicalDataset(Dataset):def __init__(self, sentences):self.sentences = sentencesdef __len__(self):return len(self.sentences)def __getitem__(self, idx):return self.sentences[idx]# 定义数据加载器
loader = DataLoader(MedicalDataset(sentences),batch_size=16,shuffle=True)
现在,我们可以开始对 BERT 模型进行二次预训练了。我们可以使用与之前相同的训练代码。
# 定义训练函数
def train(model, loader, optimizer, device):model.train()for batch in loader:input_ids = batch[0].to(device)attention_mask = batch[1].to(device)optimizer.zero_grad()loss, _ = model(input_ids=input_ids,attention_mask=attention_mask,output_hidden_states=True)[:2]loss.backward()optimizer.step()# 加载预训练的 BERT 模型
model = BertForMaskedLM.from_pretrained('bert-base-uncased')# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)# 定义优化器
optimizer = AdamW(model.parameters(), lr=5e-5)# 进行二次预训练
for epoch in range(num_epochs):train(model, train_loader, optimizer, device)# 每个 epoch 结束后测试模型的性能perplexity = evaluate(model, test_loader, device)print(f'Epoch {epoch+1}, perplexity: {perplexity:.3f}')# 保存模型model_path = f'model_epoch{epoch+1}.pt'torch.save(model.state_dict(), model_path)
这里定义了一个 train 函数来训练模型。这个函数接收一个模型、一个数据加载器、一个优化器和一个设备作为输入。它会将模型设为训练模式,并且在每个批次上运行前向传播、计算损失、计算梯度和更新参数。
接下来,我们加载预训练的 BERT 模型,并将其移动到所选设备上。我们使用 AdamW 优化器,并将学习率设置为 5e-5。
最后,我们使用一个简单的 for 循环来进行二次预训练。在每个 epoch 结束时,我们会在测试集上评估模型,并打印出 perplexity 指标。我们还会将模型保存在磁盘上,以便以后进行检索。
相关文章:
【NLP相关】基于现有的预训练模型使用领域语料二次预训练
❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博…...
使用git进行项目管理--git使用及其常用命令
使用git进行项目管理 文章目录 使用git进行项目管理git使用1.添加用户名字2.添加用户邮箱3.git初始化4.add5.commit6.添加到gitee仓库7.推送到gitee8.切换版本git常用命令git add把指定的文件添加到暂存区中添加所有修改、已删除的文件到暂存区中添加所有修改、已删除、新增的文…...
Mybatis_CRUD使用
目录1 Mybatis简介环境说明:预备知识:1.1 定义1.2 持久化为什么需要持久化服务呢?1.3 持久层1.4 为什么需要Mybatis2 依赖配置3 CRUDnamespaceselect (查询用户数据)※传值方式:于方法中传值使用Map传值insert (插入用…...
JVM的过程内分析和过程间分析有什么区别?
问: 目前所有常见的Java虚拟机对过程间分析的支持都相 当有限,要么借助大规模的方法内联来打通方法间的隔阂,以过程内分析(Intra-Procedural Analysis, 只考虑过程内部语句,不考虑过程调用的分析ÿ…...
LearnDash测验报告如何帮助改进您的课程
某一个场景。Pennywell 大学有一门课程“Introduction to Linear Algebra”。上学期进行了两次测验。20% 的学生在第一次测验中不及格,而 80% 在第二次测验中不及格。在进一步评估中,观察到第一次测验不及格的学生在第二次测验中也不及格。在第二次测验中…...
如何通过Java将Word转换为PDF
Word是我们日常编辑文档内容时十分常用的一种文档格式。但相比之下,PDF文档的格式、布局更为固定,不易被更改。在保存或传输较为重要的文档内容时,PDF文档格式也时很多人的不二选择。很多时候我们都会遇到需要将Word转换为PDF的情况。下面我就…...
DOM型XSS
DOM型XSSDOM是什么DOM型XSSDOM型XSS实操DOM是什么 DOM就是Document。 文档是由节点构成的集合,在DOM里存在许多不同类型的节点,主要有:元素节点、文本节点,属性节点。 元素节点:好比< body >< p >< h …...
04-项目立项:项目方案、可行性分析、产品规划、立项评审
文章目录4.1 项目方案立项阶段4.2 可行性分析4.3 产品规划4.4 立项评审4.4.1 立项说明书的主要内容4.4.2 立项评审流程章节总结4.1 项目方案 学习目标: 能够输出产品项目方案 项目开发设计流程的主要阶段: 立项阶段 → 设计阶段 → 开发阶段 → 测试阶…...
数据分享|NPP VIIRS夜间灯光数据(2012-2020逐月)
2011年10月美国的“索米”国家极轨卫星伙伴卫星(Suomi National Polar-orbiting Partnership or Suomi NPP)发射,它搭载的VIIRS传感器上有一个称为DNB(Day Night Band)的波段能够在500米分辨率(比原来的OLS提高6倍)的尺度上对地表开展每天覆盖全球一次的高灵敏度(比OLS提…...
网络概论笔记
概论 网络研究的是节点和边 移动互联到物联网时代,只有有互联网,网络就不会落伍 协议:对等层面的实体固定的通信规则 协议包括:语法,语义,格式,次序,动作 网络是任意连接的 服务…...
软工2023个人作业二——软件案例分析
项目内容这个作业属于哪个课程2023年北航敏捷软件工程这个作业的要求在哪里个人作业-软件案例分析我在这个课程的目标是学习并掌握现代软件开发和项目管理技术,体验敏捷开发工作流程这个作业在哪个具体方面帮助我实现目标从软件工程角度分析比较我们所熟悉的软件&am…...
python数据分析表格文档Excel数据分析器统计源码
wx供重浩:创享软件 对话框发送:python表格 获取完整源码源文件说明文档可执行文件等 在PyCharm中运行《Excel数据分析师》即可进入如图1所示的系统主界面。在该界面中,通过顶部的工具栏可以选择所要进行的操作。 具体的操作步骤如下ÿ…...
Istio Sidecar启动顺序 - 导致的应用容器网络不通
目录一、问题二、Istio 1.7及其之后版本的解决方案2.1 方式1:安装Istio时全局设置2.2 方式2:在应用Deployment通过annotation设置2.3 holdApplicationUntilProxyStarts启用效果三、Istio 1.7之前的解决方案一、问题 线上应用集成了Spring Cloud K8S Con…...
3696. 构造有向无环图
Powered by:NEFU AB-IN Link 文章目录3696. 构造有向无环图题意思路代码3696. 构造有向无环图 题意 Codeforces Round 656 (Div. 3) E 给定一个由 n个点和 m条边构成的图。 不保证给定的图是连通的。 图中的一部分边的方向已经确定,你不能改变它们的方向。 剩下的边…...
RuoYi-Flowable-Plus(代码生成)
RuoYi-Flowable-Plus搭建 若依所有扩展项目的代码生成功能都是一样的,RuoYi-Flowable-Plus为例来演示。 模块创建 1.创建新模块ruoyi-student2.编辑RuoYi-Flowable-Plus\pom.xml <dependency><groupId>com.ruoyi</groupId><artifactId>ruoy…...
训练CV模型常用的方法与技巧
最近参加一个CV比赛,看到有参赛者分享了自己训练图像识别模型时常用到的小技巧,故对其进行记录、整理,方便未来继续学习。整理了很多,它们不一定每次有用,但请记在心中,说不定未来某个任务它们就发挥了作用…...
[Java·算法·中等]LeetCode22. 括号生成
每天一题,防止痴呆题目示例分析思路1题解1分析思路2题解2分析思路3题解3👉️ 力扣原文 题目 数字 n 代表生成括号的对数,请你设计一个函数,用于能够生成所有可能的并且 有效的 括号组合。 示例 输入:n 3 输出&…...
Git项目合并实践
Git项目合并实践 一、前言 环境 操作系统:Windows 10 专业版 代码托管平台:Gitee 场景 同一个项目,在某一个时间点,被另外一个团队拷贝和修改,并且代码不在同一个仓库,最后需要合并项目 不是同一个项…...
C++实战md5、base64算法实现(附源码)
C++常用功能源码系列 文章目录 C++常用功能源码系列前言一、常用加密算法1. md5是什么二、源码1. md52. base64、decode总结前言 本文是C/C++常用功能代码封装专栏的导航贴。部分来源于实战项目中的部分功能提炼,希望能够达到你在自己的项目中拿来就用的效果,这样更好的服务…...
P6专题:P6 EPPM和PPM基本概念
目录 引言 Oracles Primavera P6 Enterprise Project Portfolio Management(P6 EPPM) Oracles Primavera P6 Professional Project Management 引言 Oracle Primavera系列软件专注于项目密集型企业,其整个项目生命周期内所有项目的组合管…...
web vue 项目 Docker化部署
Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage):…...
阿里云ACP云计算备考笔记 (5)——弹性伸缩
目录 第一章 概述 第二章 弹性伸缩简介 1、弹性伸缩 2、垂直伸缩 3、优势 4、应用场景 ① 无规律的业务量波动 ② 有规律的业务量波动 ③ 无明显业务量波动 ④ 混合型业务 ⑤ 消息通知 ⑥ 生命周期挂钩 ⑦ 自定义方式 ⑧ 滚的升级 5、使用限制 第三章 主要定义 …...
质量体系的重要
质量体系是为确保产品、服务或过程质量满足规定要求,由相互关联的要素构成的有机整体。其核心内容可归纳为以下五个方面: 🏛️ 一、组织架构与职责 质量体系明确组织内各部门、岗位的职责与权限,形成层级清晰的管理网络…...
C++中string流知识详解和示例
一、概览与类体系 C 提供三种基于内存字符串的流,定义在 <sstream> 中: std::istringstream:输入流,从已有字符串中读取并解析。std::ostringstream:输出流,向内部缓冲区写入内容,最终取…...
自然语言处理——Transformer
自然语言处理——Transformer 自注意力机制多头注意力机制Transformer 虽然循环神经网络可以对具有序列特性的数据非常有效,它能挖掘数据中的时序信息以及语义信息,但是它有一个很大的缺陷——很难并行化。 我们可以考虑用CNN来替代RNN,但是…...
深入浅出深度学习基础:从感知机到全连接神经网络的核心原理与应用
文章目录 前言一、感知机 (Perceptron)1.1 基础介绍1.1.1 感知机是什么?1.1.2 感知机的工作原理 1.2 感知机的简单应用:基本逻辑门1.2.1 逻辑与 (Logic AND)1.2.2 逻辑或 (Logic OR)1.2.3 逻辑与非 (Logic NAND) 1.3 感知机的实现1.3.1 简单实现 (基于阈…...
Python Einops库:深度学习中的张量操作革命
Einops(爱因斯坦操作库)就像给张量操作戴上了一副"语义眼镜"——让你用人类能理解的方式告诉计算机如何操作多维数组。这个基于爱因斯坦求和约定的库,用类似自然语言的表达式替代了晦涩的API调用,彻底改变了深度学习工程…...
脑机新手指南(七):OpenBCI_GUI:从环境搭建到数据可视化(上)
一、OpenBCI_GUI 项目概述 (一)项目背景与目标 OpenBCI 是一个开源的脑电信号采集硬件平台,其配套的 OpenBCI_GUI 则是专为该硬件设计的图形化界面工具。对于研究人员、开发者和学生而言,首次接触 OpenBCI 设备时,往…...
深入理解Optional:处理空指针异常
1. 使用Optional处理可能为空的集合 在Java开发中,集合判空是一个常见但容易出错的场景。传统方式虽然可行,但存在一些潜在问题: // 传统判空方式 if (!CollectionUtils.isEmpty(userInfoList)) {for (UserInfo userInfo : userInfoList) {…...
书籍“之“字形打印矩阵(8)0609
题目 给定一个矩阵matrix,按照"之"字形的方式打印这个矩阵,例如: 1 2 3 4 5 6 7 8 9 10 11 12 ”之“字形打印的结果为:1,…...
