人名分类器(nlp)
# coding: utf-8
import osos.environ['KMP_DUPLICATE_LIB_OK'] = 'True'# 导入torch工具
import jsonimport torch
# 导入nn准备构建模型
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# 导入torch的数据源 数据迭代器工具包
from torch.utils.data import Dataset, DataLoader
# 用于获得常见字母及字符规范化
import string
# 导入时间工具包
import time
# 引入制图工具包
import matplotlib.pyplot as plt
# 从io中导入文件打开方法
from io import open# 1 获取常用的字符 标点,把每个char字符作为一个token,用onehot编码表示token
# 因此我们的词表就是 char表 (字符表) 57个char
all_letters = string.ascii_letters + " ,.;'"
print(all_letters)n_letter = len(all_letters) # 词表的大小
print('字符表的长度:', n_letter)# 2 获取国家的类别种数
# 国家名 种类数
categorys = ['Italian', 'English', 'Arabic', 'Spanish', 'Scottish', 'Irish', 'Chinese', 'Vietnamese', 'Japanese','French', 'Greek', 'Dutch', 'Korean', 'Polish', 'Portuguese', 'Russian', 'Czech', 'German']
# 国家名 个数,就是模型的 (linear输出维度) 分类数
categorynum = len(categorys)
print('categorys--->', categorys)# 3 读取数据
def read_data(filename):# 3.1 初始化空列表两个my_list_x, my_list_y = [], []# 3.2 读取文件内容with open(filename, 'r', encoding='utf-8') as fr:for line in fr.readlines():# 异常点判断:改行长度<=5,说明这是异常样本,直接跳到下一行if len(line) <= 5:continuex, y = line.strip().split('\t')my_list_x.append(x)my_list_y.append(y)# 3.3 返回两个列表return my_list_x, my_list_y# 4 构建数据集
class NameClsDataset(Dataset):def __init__(self, mylist_x, mylist_y):self.mylist_x = mylist_xself.mylist_y = mylist_ydef __len__(self):return len(self.mylist_x)def __getitem__(self, item):# 01 item 异常值出处理index = min(max(item, 0), len(self.mylist_x) - 1)# 02 根据idx拿到人名 国家名x = self.mylist_x[index]y = self.mylist_y[index]# 03 完成onehottensor_x = torch.zeros(len(x), n_letter)for idx, letter in enumerate(x):tensor_x[idx][all_letters.find(letter)] = 1# 04 获得标签tensor_y = torch.tensor(categorys.index(y), dtype=torch.long)return tensor_x, tensor_y# 5 构建dataloader
def get_dataloader():filename = './data/name_classfication.txt'my_list_x, my_list_y = read_data(filename)mydataset = NameClsDataset(my_list_x, my_list_y)my_dataloader = DataLoader(mydataset,batch_size=1,shuffle=True, # 打乱顺序# drop_last=True, # 是否丢弃最后那个不足一个batch_size的数据组# collate_fn=collate_fn, # 处理一个batch的数据为整齐的维度)x, y = next(iter(my_dataloader))# print(x)# print(x.shape)# print(y)return my_dataloader# 6 创建rnn模型
class MyRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super().__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_sizeself.num_layers = num_layersself.rnn = nn.RNN(self.input_size, self.hidden_size,self.num_layers, batch_first=True)# self.linear = nn.Linear(self.hidden_size, self.hidden_size)self.linear = nn.Linear(self.hidden_size, self.output_size)self.softmax = nn.LogSoftmax(dim=-1)def forward(self, input):# input.shape = (1, 9, 57)# hidden.shape = (1, 1, 128)# rnn_output.shape = (1, 9, 128)# rnn_hn.shape = (1, 1, 128)# rnn_output, _ = self.rnn(input)rnn_output, rnn_hn = self.rnn(input)# temp.shape = (1, 128)# temp = rnn_output[0][-1].unsqueeze(0)temp = rnn_hn[0]# output.shape=(1,18)# self.softmax(output) (2, 18)output = self.linear(temp) # 可以接受三维数据return self.softmax(output), rnn_hn# 7 测试RNN
def ceshiRNN():# 1 拿到数据my_dataloader = get_dataloader()# 2 实例化模型input_size = n_letter # 字符表的大小 (词表的大小)hidden_size = 128 # 超参数 768,rnn输出维度output_size = len(categorys) # 18,分类总数my_rnn = MyRNN(input_size, hidden_size, output_size)# 3 将数据送入到模型x, y = next(iter(my_dataloader))output, hn = my_rnn(x) # output.shape = (1, 18)print(output.shape)print(hn.shape)# 8 训练RNN
def train_my_rnn():epochs = 1my_lr = 1e-3# 1 读取数据my_list_x, my_list_y = read_data('./data/name_classfication.txt')# 2 定义datasetmyDataset = NameClsDataset(my_list_x, my_list_y)# 3 实例化dataloadermy_dataloader = DataLoader(myDataset, batch_size=1, shuffle=True)# 4 实例化RNN模型input_size = 57hidden_size = 128output_size = 18my_rnn = MyRNN(input_size, hidden_size, output_size)# 5 损失函数my_crossentropy = nn.NLLLoss()# 6 优化器my_optimizer = optim.Adam(my_rnn.parameters(), lr=my_lr)# 7 日志start_time = time.time()total_iter_num = 0 # 已经训练好的样本数total_loss = 0 # 总的losstotal_loss_list = [] # 每隔多少步存储loss-avgtotal_acc_num = 0total_acc_list = [] # 存储间隔准确率acc-avg# 8 开始训练# 8.1 外部循环for epoch_idx in range(epochs):# 8.2 batch循环for i, (x, y) in enumerate(my_dataloader):# 8.3 将x送入到模型 一轮模型训练output, hn = my_rnn(x)my_loss = my_crossentropy(output, y)my_optimizer.zero_grad()my_loss.backward()my_optimizer.step()total_iter_num += 1total_loss += my_loss.item()item1 = 1 if torch.argmax(output, dim=-1).item() == y.item() else 0total_acc_num += item1# 每隔 100 步存储avg-loss acc-avgif total_iter_num % 100 == 0:# 保存一下平均损失loss_avg = total_loss / total_iter_numtotal_loss_list.append(loss_avg)# acc-avgacc_avg = total_acc_num / total_iter_numtotal_acc_list.append(acc_avg)if total_iter_num % 1000 == 0:loss_avg = total_loss / total_iter_numacc_avg = total_acc_num / total_iter_numend_time = time.time()use_time = end_time - start_timeprint('当前的训练批次:%d, 平均损失:%.5f, 训练时间:%.3f, 准确率:%.2f' % (epoch_idx + 1,loss_avg,use_time,acc_avg))# 9 保存模型torch.save(my_rnn.state_dict(), './model/my_rnn.bin')# 10 结束all_time = time.time() - start_timeprint('总耗时:', all_time)return total_loss_list, total_acc_list, all_time# 9 将模型结果进行保存,方便进行读取
def save_rnn_res():# 1 训练模型,得到需要的结果total_loss_list, total_acc_list, all_time = train_my_rnn()# 2 定义一个字典dict1 = {'loss': total_loss_list,'time': all_time,'acc': total_acc_list}# 3 保存成jsonwith open('./data/rnn_result.json', 'w') as fw:fw.write(json.dumps(dict1))# 10 读取模型结果json
def read_json(json_path):with open(json_path, 'r') as fr:# '{a:1, b:2,,,}' --> json.loads()# json.load() 加载json文件res = json.load(fr)return res# 11 绘图
def plt_RNN():# 1 拿到数据rnn_results = read_json('./data/rnn_result-epoch3.json')total_loss_list_rnn, all_time_rnn, total_acc_list_rnn = rnn_results['loss'], rnn_results['time'], rnn_results['acc']lstm_results = read_json('./data/lstm_result-epoch3.json')total_loss_list_lstm, all_time_lstm, total_acc_list_lstm = lstm_results['loss'], lstm_results['time'], lstm_results['acc']gru_results = read_json('./data/gru_result-epoch3.json')total_loss_list_gru, all_time_gru, total_acc_list_gru = gru_results['loss'], gru_results['time'], gru_results['acc']# 2 绘制loss对比曲线图plt.figure(0)plt.plot(total_loss_list_rnn, label='RNN')plt.plot(total_loss_list_lstm, label='LSTM', color='red')plt.plot(total_loss_list_gru, label='GRU', color='orange')plt.legend(loc='upper right')plt.savefig('./picture/loss.png')plt.show()# 3 绘制耗时柱状图plt.figure(1)x_data = ['RNN', 'LSTM', 'GRU']y_data = [all_time_rnn, all_time_lstm, all_time_gru]plt.bar(range(len(x_data)), y_data, tick_label=x_data)plt.savefig('./picture/use_time.png')plt.show()# 4 绘制acc曲线图plt.figure(2)plt.plot(total_acc_list_rnn, label='RNN')plt.plot(total_acc_list_lstm, label='LSTM', color='red')plt.plot(total_acc_list_gru, label='GRU', color='orange')plt.legend(loc='upper right')plt.savefig('./picture/acc.png')plt.show()# 12 定义预测输入的x --》 tensor_x
def line2tensor(x):tensor_x = torch.zeros(len(x), n_letter)for li, letter in enumerate(x):tensor_x[li][all_letters.find(letter)] = 1return tensor_x# 13 预测主函数
def rnn_predict(x):# 1 x --》 tensor_xtensor_x = line2tensor(x)# 2 实力化模型my_rnn = MyRNN(input_size=57, hidden_size=128, output_size=18)my_rnn.load_state_dict(torch.load('./model/my_rnn.bin'))# 3 预测with torch.no_grad(): # 预测时不去计算梯度input0 = tensor_x.unsqueeze(0) # input0 是三维的,rnn需要output, hn = my_rnn(input0)topv, topi = output.topk(3, 1, True)print('人名是', x)# 4 打印topk个for i in range(3):value = topv[0][i]index = topi[0][i]cate = categorys[index]print('国家名是:', cate)if __name__ == '__main__':# filename = './data/name_classfication.txt'# x, y = read_data(filename)# print(x)# print(y)# get_dataloader()# ceshiRNN()# train_my_rnn()# plt_RNN()rnn_predict('zhang')
相关文章:
人名分类器(nlp)
# coding: utf-8 import osos.environ[KMP_DUPLICATE_LIB_OK] True# 导入torch工具 import jsonimport torch # 导入nn准备构建模型 import torch.nn as nn import torch.nn.functional as F import torch.optim as optim # 导入torch的数据源 数据迭代器工具包 from torch.ut…...
斐波那契数列 相关问题 详解
斐波那契数列相关问题详解 斐波那契数列及其相关问题是算法学习中的经典主题,变形与应用非常广泛,涵盖了递推关系、动态规划、组合数学、数论等多个领域。以下是斐波那契数列的相关问题及其解法的详解。 1. 经典斐波那契数列 定义 初始条件࿱…...
Pytorch微调深度学习模型
在公开数据训练了模型,有时候需要拿到自己的数据上微调。今天正好做了一下微调,在此记录一下微调的方法。用Pytorch还是比较容易实现的。 网上找了很多方法,以及Chatgpt也给了很多方法,但是不够简洁和容易理解。 大体步骤是&…...
springboot 使用笔记
1.springboot 快速启动项目 注意:该启动只是临时启动,不能关闭终端面板 cd /www/wwwroot java -jar admin.jar2.脚本启动 linux shell脚本启动springboot服务 3.java一键部署springboot 第5条 https://blog.csdn.net/qq_30272167/article/details/1…...
网络安全基础——网络安全法
填空题 1.根据**《中华人民共和国网络安全法》**第二十条(第二款),任何组织和个人试用网路应当遵守宪法法律,遵守公共秩序,遵守社会公德,不危害网络安全,不得利用网络从事危害国家安全、荣誉和利益,煽动颠…...
SCAU软件体系结构实验四 组合模式
目录 一、题目 二、源码 一、题目 个人(Person)与团队(Team)可以形成一个组织(Organization):组织有两种:个人组织和团队组织,多个个人可以组合成一个团队,不同的个人与团队可以组合成一个更大的团队。 使用控制台或者JavaFx界面…...
Amazon商品详情API接口:电商创新与用户体验的驱动力
在电子商务蓬勃发展的今天,作为全球最大的电商平台之一,亚马逊(Amazon)凭借其强大的技术实力和丰富的商品资源,为全球用户提供了优质的购物体验。其中,Amazon商品详情API接口在电商创新与用户体验提升方面扮…...
手机无法连接服务器1302什么意思?
你有没有遇到过手机无法连接服务器,屏幕上显示“1302”这样的错误代码?尤其是在急需使用手机进行工作或联系朋友时,突然出现的连接问题无疑会带来不少麻烦。那么,什么是1302错误,它又意味着什么呢? 1302错…...
Android adb shell dumpsys audio 信息查看分析详解
Android adb shell dumpsys audio 信息查看分析详解 一、前言 Android 如果要分析当前设备的声音通道相关日志, 仅仅看AudioService的日志是看不到啥日志的,但是看整个audio关键字的日志又太多太乱了, 所以可以看一下系统提供的一个调试指令…...
Python 网络爬虫操作指南
网络爬虫是自动化获取互联网上信息的一种工具。它广泛应用于数据采集、分析以及实现信息聚合等众多领域。本文将为你提供一个完整的Python网络爬虫操作指南,帮助你从零开始学习并实现简单的网络爬虫。我们将涵盖基本的爬虫概念、Python环境配置、常用库介绍。 上传…...
基于FPGA的2FSK调制-串口收发-带tb仿真文件-实际上板验证成功
基于FPGA的2FSK调制 前言一、2FSK储备知识二、代码分析1.模块分析2.波形分析 总结 前言 设计实现连续相位 2FSK 调制器,2FSK 的两个频率为:fI15KHz,f23KHz,波特率为 1500 bps,比特0映射为f 载波,比特1映射为 载波。 1)…...
JavaScript的基础数据类型
一、JavaScript中的数组 定义 数组是一种特殊的对象,用于存储多个值。在JavaScript中,数组可以包含不同的数据类型,如数字、字符串、对象、甚至其他数组。数组的创建有两种常见方式: 字面量表示法:let fruits [apple…...
第三讲 架构详解:“隐语”可信隐私计算开源框架
目录 隐语架构 隐语架构拆解 产品层 算法层 计算层 资源层 互联互通 跨域管控 本文主要是记录参加隐语开源社区推出的第四期隐私计算实训营学习到的相关内容。 隐语架构 隐语架构拆解 产品层 产品定位: 通过可视化产品,降低终端用户的体验和演…...
JDBC编程---Java
目录 一、数据库编程的前置 二、Java的数据库编程----JDBC 1.概念 2.JDBC编程的优点 三.导入MySQL驱动包 四、JDBC编程的实战 1.创造数据源,并设置数据库所在的位置,三条固定写法 2.建立和数据库服务器之间的连接,连接好了后ÿ…...
Python绘制太极八卦
文章目录 系列目录写在前面技术需求1. 图形绘制库的支持2. 图形绘制功能3. 参数化设计4. 绘制控制5. 数据处理6. 用户界面 完整代码代码分析1. rset() 函数2. offset() 函数3. taiji() 函数4. bagua() 函数5. 绘制过程6. 技术亮点 写在后面 系列目录 序号直达链接爱心系列1Pyth…...
Spring框架特性及包下载(Java EE 学习笔记04)
1 Spring 5的新特性 Spring 5是Spring当前最新的版本,与历史版本对比,Spring 5对Spring核心框架进行了修订和更新,增加了很多新特性,如支持响应式编程等。 更新JDK基线 因为Spring 5代码库运行于JDK 8之上,所以Spri…...
Linux关于vim的笔记
Linux关于vim的笔记:(vimtutor打开vim 教程) --------------------------------------------------------------------------------------------------------------------------------- 1. 光标在屏幕文本中的移动既可以用箭头键,也可以使用 hjkl 字母键…...
linux mount nfs开机自动挂载远程目录
要在Linux系统中实现开机自动挂载NFS共享目录,你需要编辑/etc/fstab文件。以下是具体步骤和示例: 确保你的系统已经安装了NFS客户端。如果没有安装,可以使用以下命令安装: sudo apt-install nfs-common 编辑/etc/fstab文件&#…...
【vue】导航守卫
什么是导航守卫 在vue路由切换过程中对行为做个限制 全局前置守卫 route.beforeEach((to, from, next)) > {// to是切换到的路由// from是正要离开的路由// next控制是否允许进入目标路由next(false); //不允许 }路由级别的导航守卫 const routes [{path: /User,name: U…...
基于Matlab实现LDPC编码
在无线通信和数据存储领域,LDPC(低密度奇偶校验码)编码是一种高效、纠错能力强大的错误校正技术。本MATLAB仿真程序全面地展示了如何在AWGN(加性高斯白噪声)信道下应用LDPC编码与BPSK(二进制相移键控&#…...
FigmaCN:解决Figma英文界面障碍的设计师专属中文方案
FigmaCN:解决Figma英文界面障碍的设计师专属中文方案 【免费下载链接】figmaCN 中文 Figma 插件,设计师人工翻译校验 项目地址: https://gitcode.com/gh_mirrors/fi/figmaCN 作为一名设计师,您是否曾因Figma全英文界面而减慢工作流程&…...
从零到一:深度解析BertTokenizer.from_pretrained的加载机制与实战技巧
1. 初识BertTokenizer.from_pretrained:你的NLP敲门砖 第一次接触Hugging Face的Transformers库时,我被BertTokenizer.from_pretrained()这个方法深深吸引了。它就像是一把万能钥匙,能快速打开各种预训练语言模型的大门。记得当时我尝试用传统…...
终极无损视频剪辑神器:LosslessCut完整指南与5大实用技巧
终极无损视频剪辑神器:LosslessCut完整指南与5大实用技巧 【免费下载链接】lossless-cut The swiss army knife of lossless video/audio editing 项目地址: https://gitcode.com/gh_mirrors/lo/lossless-cut 你是否曾因视频剪辑导致画质下降而烦恼ÿ…...
RePKG终极指南:Wallpaper Engine资源提取与转换的完整解决方案
RePKG终极指南:Wallpaper Engine资源提取与转换的完整解决方案 【免费下载链接】repkg Wallpaper engine PKG extractor/TEX to image converter 项目地址: https://gitcode.com/gh_mirrors/re/repkg 你是否曾经遇到过这样的问题?在Wallpaper Eng…...
vue基于springboot的高校二手书交易系统
目录同行可拿货,招校园代理 ,本人源头供货商功能模块分析交易流程模块后台管理模块技术实现要点扩展功能建议项目技术支持源码获取详细视频演示 :文章底部获取博主联系方式!同行可合作同行可拿货,招校园代理 ,本人源头供货商 功能模块分析 用户管理模块…...
手把手教你用Arm Cortex-A715手册:从RAS到调试,一份给芯片设计者的实战笔记
Cortex-A715实战指南:芯片设计者的RAS与调试技术精要 在当今高性能计算领域,Arm Cortex-A715处理器核心凭借其卓越的能效比和性能表现,已成为众多芯片设计项目的首选。本文将从工程实践角度,深入剖析Cortex-A715的两个关键子系统&…...
STM32CubeMX实战指南:从零搭建HAL库项目与LED控制
1. STM32CubeMX与HAL库开发入门 第一次接触STM32开发的朋友可能会被各种专业术语吓到——寄存器、固件库、HAL库、时钟树配置... 作为一个从51单片机转战STM32的"过来人",我完全理解这种困惑。三年前我刚开始用STM32F103时,光是搭建开发环境就…...
Phi-4-mini-reasoning效果展示:Chainlit中实时显示推理耗时与token生成速率
Phi-4-mini-reasoning效果展示:Chainlit中实时显示推理耗时与token生成速率 1. 模型简介 Phi-4-mini-reasoning 是一个基于合成数据构建的轻量级开源模型,专注于高质量、密集推理的数据处理。作为Phi-4模型家族的一员,它特别强化了数学推理…...
OpenLara最佳实践:开发高质量游戏引擎的10个关键原则
OpenLara最佳实践:开发高质量游戏引擎的10个关键原则 【免费下载链接】OpenLara Classic Tomb Raider open-source engine 项目地址: https://gitcode.com/gh_mirrors/op/OpenLara OpenLara作为一款经典古墓丽影开源引擎,凭借跨平台设计和高效渲染…...
Ryujinx:C编写的Nintendo Switch模拟器技术解析与应用指南
Ryujinx:C#编写的Nintendo Switch模拟器技术解析与应用指南 【免费下载链接】Ryujinx 用 C# 编写的实验性 Nintendo Switch 模拟器 项目地址: https://gitcode.com/GitHub_Trending/ry/Ryujinx Ryujinx是一款用C#编写的实验性Nintendo Switch模拟器ÿ…...
