【论文复现】基于BERT的语义分析实现
📝个人主页🌹:Eternity._
🌹🌹期待您的关注 🌹🌹
❀ WRN: 宽度残差网络
- 概述
- 语义分类
- 文本分类
- 情感分类
- 实现原理
- 核心逻辑
- pre_deal.py
- train.py
- test_demo.py
- 实现方式&演示效果
- 训练阶段
- 测试阶段
概述
在之前的文章中,我们介绍了BERT模型。BERT作为一种预训练语言模型,它具有很好的兼容性,能够运用在各种下游任务中,本文的主要目的是利用数据集来对BERT进行训练,从而实现一个语义分类的模型。
本文所涉及的所有资源的获取方式:这里
语义分类
语义分类是自然语言处理任务中的一种,包含文本分类、情感分析
文本分类
文本分类是指给定文本a,将文本分类为n个类别中的一个或多个。常见的应用包括文本话题分类,情感分类,具体的分类方向有有二分类,多分类和多标签分类。
文本分类可以采用传统机器学习方法(贝叶斯,svm等)和深度学习方法(fastText,TextCNN等)实现。
举例而言,对于一个对话数据集,我们可以用1、2、3表示他们的话题,如家庭、学校、工作等,而文本分类的目的,则是把这些文本的话题划分到给定的三种类别中。
情感分类
情感分析是自然语言处理中常见的场景,比如商品评价等。通过情感分析,可以挖掘产品在各个维度的优劣。情感分类其实也是一种特殊的文本分类,只是他更聚焦于情感匹配词典。
举例而言,情感分类可以用0/1表示负面评价/正面评价,例子如下:
0,不好的,319房间有故臭味。要求换房说满了,我是3月去的。在路上认识了一个上海人,他说他退房前也住的319,也是一股臭味。而且这个去不掉,特别是晚上,很浓。不知道是厕所的还是窗外的。服务一般,门前有绿皮公交去莫高窟,不过敦煌宾馆也有,下次住敦煌宾馆。再也不住这个酒店了,热水要放半个小时才有。
1,不错的酒店,大堂和餐厅的环境都不错。但由于给我的是一间走廊尽头的房间,所以房型看上去有点奇怪。客厅和卧室是连在一起的,面积偏小。服务还算到位,总的来说,性价比还是不错的。
本文将以情感二分类为例,实现如何利用BERT进行语义分析。
实现原理
首先,基于BERT预训练模型,能将一个文本转换成向量,作为模型的输入。
在BERT预训练模型的基础上,新增一个全连接层,将输入的向量通过训练转化成一个tensor作为输出,其中这个tensor的维度则是需要分类的种类,具体的值表示每个种类的概率。例如:
[0.25,0.75]
指代的是有0.25的概率属于第一类,有0.75的概率属于第二类,因此,理论输出结果是把该文本分为第二类。
核心逻辑
pre_deal.py
import csv
import random
from datasets import load_datasetdef read_file(file_path):csv_reader = csv.reader(open(file_path, encoding='UTF-8'))num = 0data = []for row in csv_reader:if num == 0:num = 1continuecomment_data = [row[1], int(row[0])]if len(comment_data[0]) > 500:text=comment_data[0]sub_texts, start, length = [], 0, len(text)while start < length:piecedata=[text[start: start + 500], comment_data[1]]data.append(piecedata)start += 500else:data.append(comment_data)random.shuffle(data)return data
对输入的csv文件进行处理,其中我们默认csv文件的格式是[label,text],将用于训练的内容读取出来,转化为numpy格式,其中,如果遇到有些文本过长(超过模型的输入),将其截断,分为多个文本段来输入。在最后,会通过shuffle函数进行打乱。
train.py
train.py定义了几个函数,用于训练。
首先是Bertmodel类,定义了基于Bert的训练模型:
class Bertmodel(nn.Module):def __init__(self, output_dim, model_path):super(Bertmodel, self).__init__()# 导入bert模型self.bert = BertModel.from_pretrained(model_path)# 外接全连接层self.layer1 = nn.Linear(768, output_dim)def forward(self, tokens):res = self.bert(**tokens)res = self.layer1(res[1])res = res.softmax(dim=1)return res
该模型由Bert和一个全连接层组成,最后经过softmax激活函数。
其次是一个评估函数,用来计算模型结果的准确性
def evaluate(net, comments_data, labels_data, device, tokenizer):ans = 0 # 输出结果i = 0step = 8 # 每轮一次读取多少条数据tot = len(comments_data)while i <= tot:print(i)comments = comments_data[i: min(i + step, tot)]tokens_X = tokenizer(comments, padding=True, truncation=True, return_tensors='pt').to(device=device)res = net(tokens_X) # 获得到预测结果y = torch.tensor(labels_data[i: min(i + step, tot)]).reshape(-1).to(device=device)ans += (res.argmax(axis=1) == y).sum()i += stepreturn ans / tot
原理就是,将文本转化为tokens,输入给模型,而后利用返回的结果,计算准确性
下面展示了开始训练的主函数,在训练的过程中,进行后向传播,储存checkpoints模型
def training(net, tokenizer, loss, optimizer, train_comments, train_labels, test_comments, test_labels,device, epochs):max_acc = 0.5 # 初始化模型最大精度为0.5for epoch in tqdm(range(epochs)):step = 8i, sum_loss = 0, 0tot=len(train_comments)while i < tot:comments = train_comments[i: min(i + step, tot)]tokens_X = tokenizer(comments, padding=True, truncation=True, return_tensors='pt').to(device=device)res = net(tokens_X)y = torch.tensor(train_labels[i: min(i + step, len(train_comments))]).reshape(-1).to(device=device)optimizer.zero_grad() # 清空梯度l = loss(res, y) # 计算损失l.backward() # 后向传播optimizer.step() # 更新梯度sum_loss += l.detach() # 累加损失i += steptrain_acc = evaluate(net, train_comments, train_labels)test_acc = evaluate(net, test_comments, test_labels)print('\n--epoch', epoch + 1, '\t--loss:', sum_loss / (len(train_comments) / 8), '\t--train_acc:', train_acc,'\t--test_acc', test_acc)# 保存模型参数,并重设最大值if test_acc > max_acc:# 更新历史最大精确度max_acc = test_acc# 保存模型max_acc = test_acctorch.save({'epoch': epoch,'state_dict': net.state_dict(),'optimizer': optimizer.state_dict()}, 'model/checkpoint_net.pth')
训练结果表示如下:
--epoch 0 --train_acc: tensor(0.6525, device='cuda:1') --test_acc tensor(0.6572, device='cuda:1')0%| | 0/20 [00:00<?, ?it/s]5%|▌ | 1/20 [01:48<34:28, 108.88s/it]10%|█ | 2/20 [03:38<32:43, 109.10s/it]15%|█▌ | 3/20 [05:27<30:56, 109.20s/it]20%|██ | 4/20 [07:15<29:02, 108.93s/it]25%|██▌ | 5/20 [09:06<27:23, 109.58s/it]30%|███ | 6/20 [10:55<25:29, 109.26s/it]35%|███▌ | 7/20 [12:44<23:40, 109.28s/it]40%|████ | 8/20 [14:33<21:51, 109.29s/it]45%|████▌ | 9/20 [16:23<20:04, 109.49s/it]50%|█████ | 10/20 [18:13<18:15, 109.59s/it]55%|█████▌ | 11/20 [20:03<16:27, 109.72s/it]60%|██████ | 12/20 [21:52<14:35, 109.45s/it]65%|██████▌ | 13/20 [23:41<12:45, 109.35s/it]70%|███████ | 14/20 [25:30<10:54, 109.14s/it]75%|███████▌ | 15/20 [27:19<09:05, 109.03s/it]80%|████████ | 16/20 [29:07<07:15, 108.84s/it]85%|████████▌ | 17/20 [30:56<05:26, 108.86s/it]90%|█████████ | 18/20 [32:44<03:37, 108.75s/it]95%|█████████▌| 19/20 [34:33<01:48, 108.73s/it]
100%|██████████| 20/20 [36:22<00:00, 108.71s/it]
100%|██████████| 20/20 [36:22<00:00, 109.11s/it]--epoch 1 --loss: tensor(1.2426, device='cuda:1') --train_acc: tensor(0.6759, device='cuda:1') --test_acc tensor(0.6789, device='cuda:1')--epoch 2 --loss: tensor(1.0588, device='cuda:1') --train_acc: tensor(0.8800, device='cuda:1') --test_acc tensor(0.8708, device='cuda:1')--epoch 3 --loss: tensor(0.8543, device='cuda:1') --train_acc: tensor(0.8988, device='cuda:1') --test_acc tensor(0.8887, device='cuda:1')--epoch 4 --loss: tensor(0.8208, device='cuda:1') --train_acc: tensor(0.9111, device='cuda:1') --test_acc tensor(0.8990, device='cuda:1')--epoch 5 --loss: tensor(0.8024, device='cuda:1') --train_acc: tensor(0.9206, device='cuda:1') --test_acc tensor(0.9028, device='cuda:1')--epoch 6 --loss: tensor(0.7882, device='cuda:1') --train_acc: tensor(0.9227, device='cuda:1') --test_acc tensor(0.9024, device='cuda:1')--epoch 7 --loss: tensor(0.7749, device='cuda:1') --train_acc: tensor(0.9288, device='cuda:1') --test_acc tensor(0.9036, device='cuda:1')--epoch 8 --loss: tensor(0.7632, device='cuda:1') --train_acc: tensor(0.9352, device='cuda:1') --test_acc tensor(0.9061, device='cuda:1')--epoch 9 --loss: tensor(0.7524, device='cuda:1') --train_acc: tensor(0.9421, device='cuda:1') --test_acc tensor(0.9090, device='cuda:1')--epoch 10 --loss: tensor(0.7445, device='cuda:1') --train_acc: tensor(0.9443, device='cuda:1') --test_acc tensor(0.9103, device='cuda:1')--epoch 11 --loss: tensor(0.7397, device='cuda:1') --train_acc: tensor(0.9480, device='cuda:1') --test_acc tensor(0.9128, device='cuda:1')--epoch 12 --loss: tensor(0.7321, device='cuda:1') --train_acc: tensor(0.9505, device='cuda:1') --test_acc tensor(0.9123, device='cuda:1')--epoch 13 --loss: tensor(0.7272, device='cuda:1') --train_acc: tensor(0.9533, device='cuda:1') --test_acc tensor(0.9140, device='cuda:1')--epoch 14 --loss: tensor(0.7256, device='cuda:1') --train_acc: tensor(0.9532, device='cuda:1') --test_acc tensor(0.9111, device='cuda:1')--epoch 15 --loss: tensor(0.7186, device='cuda:1') --train_acc: tensor(0.9573, device='cuda:1') --test_acc tensor(0.9123, device='cuda:1')--epoch 16 --loss: tensor(0.7135, device='cuda:1') --train_acc: tensor(0.9592, device='cuda:1') --test_acc tensor(0.9136, device='cuda:1')--epoch 17 --loss: tensor(0.7103, device='cuda:1') --train_acc: tensor(0.9601, device='cuda:1') --test_acc tensor(0.9128, device='cuda:1')--epoch 18 --loss: tensor(0.7091, device='cuda:1') --train_acc: tensor(0.9590, device='cuda:1') --test_acc tensor(0.9086, device='cuda:1')--epoch 19 --loss: tensor(0.7084, device='cuda:1') --train_acc: tensor(0.9626, device='cuda:1') --test_acc tensor(0.9123, device='cuda:1')--epoch 20 --loss: tensor(0.7038, device='cuda:1') --train_acc: tensor(0.9628, device='cuda:1') --test_acc tensor(0.9107, device='cuda:1')
最终训练结果,在训练集上达到了96.28%的准确率,在测试集上达到了91.07%的准确率
test_demo.py
这个函数提供了一个调用我们储存的checkpoint模型来进行预测的方式,将input转化为berttokens,而后输入给模型,返回输出结果。
input_text=['这里环境很好,风光美丽,下次还会再来的。']
Bert_model_path = 'xxxx'
output_path='xxxx'
device = torch.device('cpu')
checkpoint = torch.load(output_path,map_location='cpu')model = Bertmodel(output_dim=2,model_path=Bert_model_path)
model.load_state_dict(checkpoint,False)
# print(model)
tokenizer = BertTokenizer.from_pretrained(Bert_model_path,model_max_length=512)tokens_X = tokenizer(input_text, padding=True, truncation=True, return_tensors='pt').to(device='cpu')
model.eval()
output=model(tokens_X)
print(output)
out = torch.unsqueeze(output.argmax(dim=1), dim=1)
result = out.numpy()
print(result)
if result[0][0]==1:print("positive")
else:print("negative")
实现方式&演示效果
训练阶段
首先找到能够拿来训练的数据,运行pre_deal.py进行预处理,而后可以在main.py修改模型的相关参数,运行main.py开始训练。
这个过程,可能会收到硬件条件的影响,推荐使用cuda进行训练。如果实在训练不了,可以直接调用附件中对应的训练好的模型来进行预测。
测试阶段
运行test_demo.py,测试输入文本的分类结果
输入
input_text=['这里环境很好,风光美丽,下次还会再来的。']
输出
tensor([[0.3191, 0.6809]], grad_fn=<SoftmaxBackward0>)
[[1]]
positive
得出,这句话的情感分类是positive(正面)
编程未来,从这里启航!解锁无限创意,让每一行代码都成为你通往成功的阶梯,帮助更多人欣赏与学习!
更多内容详见:这里
相关文章:

【论文复现】基于BERT的语义分析实现
📝个人主页🌹:Eternity._ 🌹🌹期待您的关注 🌹🌹 ❀ WRN: 宽度残差网络 概述语义分类文本分类情感分类 实现原理 核心逻辑pre_deal.pytrain.pytest_demo.py 实现方式&演示效果训练阶段测试阶…...
CTF-RE: STL逆向 [NewStarCTF 2023 公开赛道 STL] WP
多看看STL题就会了,很简单 int __fastcall main(int argc, const char **argv, const char **envp) {__int64 v3; // rbx__int64 v4; // raxchar v5; // bl_BYTE *v6; // rax_QWORD *v7; // rax__int64 v8; // rax__int64 v9; // raxint i; // [rsp0h] [rbp-250h]int j; // [r…...
实习冲刺第三十六天
46.全排列 给定一个不含重复数字的数组 nums ,返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1: 输入:nums [1,2,3] 输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]示例 2: 输入&#…...

【Zemax光学设计实训三】---激光缩束镜的设计优化
前言与目录 技术设计要求: 设计一个激光扩束镜,使用的波长为1064nm,输入光束直径为10mm,输出光束的直径为2mm,且输入光束和输出光束平行(即平行光入射,平行光出射)。要求只使用两片…...

TCP/IP协议簇自学笔记
摘抄于大学期间记录在QQ空间的一篇自学笔记,当前清理空间,本来想直接删除掉的,但是感觉有些舍不得,因此先搬移过来。 曾经,我只知道socket函数能进行网络间数据的通信,知道tcp/ip协议也是用来进行网络数据…...

Spring Boot教程之十一:获取Request 请求 和 Put请求
如何在 Spring Boot 中获取Request Body? Java 语言是所有编程语言中最流行的语言之一。使用 Java 编程语言有几个优点,无论是出于安全目的还是构建大型分发项目。使用 Java 的优点之一是 Java 试图借助类、继承、多态等概念将语言中的每个概念与现实世…...

计算机网络(二)
ip地址:11010010:01011110:00100100:00010100 子网掩码:11111111:11111111:11111111:11000000 and :11010010:01011110:00100100:00000000 210.94.36.0的下一站为R1 因为255为11111111 192为ÿ…...

如何在Python中进行数学建模?
数学建模是数据科学中使用的强大工具,通过数学方程和算法来表示真实世界的系统和现象。Python拥有丰富的库生态系统,为开发和实现数学模型提供了一个很好的平台。本文将指导您完成Python中的数学建模过程,重点关注数据科学中的应用。 数学建…...
JavaSE——类与对象(5)
一、抽象类 1.1为什么需要抽象类 父类的某些方法,不确定怎么实现,也不需要实现。 class Animal{public String name;public Animal(String name){this.name name;}public void eat()//这里实现了也没有意义{System.out.println("这是一个动物&am…...

Istio笔记01--快速体验Istio
Istio笔记01--快速体验Istio 介绍部署与测试部署k8s安装istio测试istio 注意事项说明 介绍 Istio是当前最热门的服务网格产品,已经被广泛应用于各个云厂商和IT互联网公司。企业可以基于Istio轻松构建服务网格,在接入过程中应用代码无需更改,…...
面试小札:Java如何实现并发编程
多线程基础 继承Thread类 定义一个类继承自 Thread 类,重写 run 方法。在 run 方法中编写线程要执行的任务逻辑。例如: java class MyThread extends Thread { Override public void run() { System.out.println("线程执行的任务…...
java-a+b 开启java语法学习
代码 (ab) import java.util.Scanner; //导入 java.util包中的Scanner 类,允许读取键盘输入数据public class Main { // 创建一个公共类 Mainpublic static void main(String[] args) {//程序入口点,main方法Scanner scanner new Scanner(…...
RNN模型文本预处理--数据增强方法
数据增强方法 数据增强是自然语言处理(NLP)中常用的一种技术,通过生成新的训练样本来扩充数据集,从而提高模型的泛化能力和性能。回译数据增强法是一种常见的数据增强方法,特别适用于文本数据。 回译数据增强法 定义…...
maven 中<packaging>pom</packaging>配置使用
在 Maven 项目的 pom.xml 文件中, 元素用于指定项目的打包类型。默认情况下,如果 元素没有被显式定义,Maven 会假设其值为 jar。但是,当您设置 pom 时,这意味着该项目是一个 POM(Project Object Model&…...

【Python中while循环】
一、深拷贝、浅拷贝 1、需求 1)拷贝原列表产生一个新列表 2)想让两个列表完全独立开(针对改操作,读的操作不改变) 要满足上述的条件,只能使用深拷贝 2、如何拷贝列表 1)直接赋值 # 定义一个…...
【深度学习】服务器常见命令
1、虚拟环境的安装位置 先进入虚拟环境 which python2、升序查看文件内容 ls -ltr3、查看服务器主机空间使用情况 df -hdf -h .4、查看本地空间使用情况 du -sh ./*du -sh * | sort -nr5、查找并删除进程 # 查找 ps aux# 删除 kill -KILL pid6、查看服务器配置 lscpuuna…...
技术分析模板
文章目录 概要整体架构流程技术名词解释技术细节小结 概要 提示:这里可以添加技术概要 例如: openAI 的 GPT 大模型的发展历程。 整体架构流程 提示:这里可以添加技术整体架构 例如: 在语言模型中,编码器和解码器…...

python:文件操作
一、文件路径 在Windows系统中,每个磁盘都有自己的根目录,用分区名加反斜杠来表示。我们定位文件的位置有两种方法,一种是绝对路径,另一种是相对路径。绝对路径是从根目录出发的路径,路径中的每个路径之间用反斜杠来分…...
Nginx和Apache有什么异同?
Nginx和Apache都是广泛使用的Web服务器软件,它们各自具有独特的特点和优势,适用于不同的应用场景。以下是关于Nginx和Apache的不同、相同以及使用区别的详细分析: 一、不同点 资源占用与并发处理能力: Nginx使用更少的内存和CPU资…...

泰州榉之乡全托机构探讨:自闭症孩子精细动作训练之法
当发现自闭症孩子精细动作落后时,家长们往往会感到担忧和困惑。那么,自闭症孩子精细动作落后该如何训练呢?今天,泰州榉之乡全托机构就来为大家详细解答。 榉之乡大龄自闭症托养机构在江苏、广东、江西等地都有分校,一直…...

树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法
树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作,无需更改相机配置。但是,一…...
uni-app学习笔记二十二---使用vite.config.js全局导入常用依赖
在前面的练习中,每个页面需要使用ref,onShow等生命周期钩子函数时都需要像下面这样导入 import {onMounted, ref} from "vue" 如果不想每个页面都导入,需要使用node.js命令npm安装unplugin-auto-import npm install unplugin-au…...
什么是EULA和DPA
文章目录 EULA(End User License Agreement)DPA(Data Protection Agreement)一、定义与背景二、核心内容三、法律效力与责任四、实际应用与意义 EULA(End User License Agreement) 定义: EULA即…...
三体问题详解
从物理学角度,三体问题之所以不稳定,是因为三个天体在万有引力作用下相互作用,形成一个非线性耦合系统。我们可以从牛顿经典力学出发,列出具体的运动方程,并说明为何这个系统本质上是混沌的,无法得到一般解…...
爬虫基础学习day2
# 爬虫设计领域 工商:企查查、天眼查短视频:抖音、快手、西瓜 ---> 飞瓜电商:京东、淘宝、聚美优品、亚马逊 ---> 分析店铺经营决策标题、排名航空:抓取所有航空公司价格 ---> 去哪儿自媒体:采集自媒体数据进…...

如何在最短时间内提升打ctf(web)的水平?
刚刚刷完2遍 bugku 的 web 题,前来答题。 每个人对刷题理解是不同,有的人是看了writeup就等于刷了,有的人是收藏了writeup就等于刷了,有的人是跟着writeup做了一遍就等于刷了,还有的人是独立思考做了一遍就等于刷了。…...

基于Java+VUE+MariaDB实现(Web)仿小米商城
仿小米商城 环境安装 nodejs maven JDK11 运行 mvn clean install -DskipTestscd adminmvn spring-boot:runcd ../webmvn spring-boot:runcd ../xiaomi-store-admin-vuenpm installnpm run servecd ../xiaomi-store-vuenpm installnpm run serve 注意:运行前…...

MacOS下Homebrew国内镜像加速指南(2025最新国内镜像加速)
macos brew国内镜像加速方法 brew install 加速formula.jws.json下载慢加速 🍺 最新版brew安装慢到怀疑人生?别怕,教你轻松起飞! 最近Homebrew更新至最新版,每次执行 brew 命令时都会自动从官方地址 https://formulae.…...

mac:大模型系列测试
0 MAC 前几天经过学生优惠以及国补17K入手了mac studio,然后这两天亲自测试其模型行运用能力如何,是否支持微调、推理速度等能力。下面进入正文。 1 mac 与 unsloth 按照下面的进行安装以及测试,是可以跑通文章里面的代码。训练速度也是很快的。 注意…...

DeepSeek源码深度解析 × 华为仓颉语言编程精粹——从MoE架构到全场景开发生态
前言 在人工智能技术飞速发展的今天,深度学习与大模型技术已成为推动行业变革的核心驱动力,而高效、灵活的开发工具与编程语言则为技术创新提供了重要支撑。本书以两大前沿技术领域为核心,系统性地呈现了两部深度技术著作的精华:…...