当前位置: 首页 > news >正文

word2vector训练代码详解

目录

1.代码实现

2.知识点 


 

1.代码实现

#导包
import math
import torch
from torch import nn
import dltools
#加载PTB数据集  ,需要把PTB数据集的文件夹放在代码上一级目录的data文件中,不用解压
#批次大小、窗口大小、噪声词大小
batch_size, max_window_size, num_noise_words = 512, 5, 5  
#获取数据集迭代器、词汇表
data_iter, vocab = dltools.load_data_ptb(batch_size, max_window_size, num_noise_words)
#讲解嵌入层embedding的用法(此行代码无用)#嵌入层
#通过嵌入层来获取skip—gram的中心词向量和上下文词向量
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)  
# num_embeddings就是词表大小
# X的shape=(batch_size, num_steps)
# --one_hot编码--->(batch_size, num_steps, num_embedding(vocab_size))
# --点乘中心词矩阵-->(batch_size, num_steps, embed_size)
embed.weight.shape   #讲解嵌入层embedding的用法(此行代码无用)
torch.Size([20, 4])

embedding层先one_hot编码,再进行与embedding层的矩阵(num_embeddings,embedding_dim)乘法 

#构造skip_gram的前向传播
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):"""embed_v:表示对中心词进行embedding层embed_u:对上下文词进行embedding层 """v = embed_v(center)                 #中心词的词向量表达u = embed_u(contexts_and_negatives) #上下文词的词向量表达#用中心词来预测上下文词#u_shape = (batch_size, num_steps, embed_size)---->(batch_size, embed_size, num_steps)进行矩阵乘法pred = torch.bmm(v, u.permute(0, 2, 1))  #矩阵乘法(bmm三维乘法),不用管batch_size维度return pred
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed)
tensor([[[3.1980, 3.1980, 3.1980, 3.1980]],[[3.1980, 3.1980, 3.1980, 3.1980]]], grad_fn=<BmmBackward0>)
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed).shape

 torch.Size([2, 1, 4])

#带掩码的二元交叉熵损失
class SigmoidBCELoss(nn.Module):def __init__(self):super().__init__()  #直接继承父类的初始化属性和方法def forward(self, inputs, target, mask=None):#nn.functional.binary_cross_entropy_with_logits表示返回的不是转化后的概率,是原始计算的数据结果#weight=mask权重将掩码带上#reduction='none'表示不将计算结果聚合,算损失时(默认聚合)out = nn.functional.binary_cross_entropy_with_logits(inputs, target, weight=mask, reduction='none')return out.mean(dim=1)  #计算结果是二维的,在索引1维度上聚合求平均
loss = SigmoidBCELoss()
[[1.1, -2.2, 3.3, -4.4]] * 2
[[1.1, -2.2, 3.3, -4.4], [1.1, -2.2, 3.3, -4.4]]
torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2).shape

 torch.Size([2, 4])

#假设数据测试
pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
#mask每一行都有4个数值,所以* mask.shape[1]=4
#但是mask中的数值0表示权重,是补充步长的,不重要,需要计算有效序列的损失平均值,所以 / mask.sum(axis=1)
loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)

 tensor([0.9352, 1.8462])

#初始化模型参数,定义两个嵌入层
#一开始,embed_weights会标准正态分布的数据初始化
#两个embedding层的参数不一样,不能重复使用,需要初始化定义两个
embed_size = 100
net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size),nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size))

 

#定义训练过程
def train(net, data_iter, lr, num_epochs, device=dltools.try_gpu()):#修改embedding层的初始化方法,使用nn.init.xavier_uniform_初始化embed.weight权重,在NLP中不使用标准正态分布的额数据初始化权重def init_weights(m):if type(m) == nn.Embedding:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)  net = net.to(device)#设置梯度下降的优化器optimizer = torch.optim.Adam(net.parameters(), lr=lr)#设置绘制可视化的动图(epoch——loss)animator = dltools.Animator(xlabel='epoch', ylabel='loss', xlim=[1, num_epochs])#设置累加metric = dltools.Accumulator(2)   #2种数据需要累加for epoch in range(num_epochs):  #遍历训练次数#设置计时器, 赋值批次数量timer, num_batches = dltools.Timer(), len(data_iter)    #data_iter是分好批次的数据集,长度就是批次数量num_batchesfor i, batch in enumerate(data_iter):   #i是索引, batch是取出的一批批数据#梯度清零optimizer.zero_grad()#接收中心词, 上下文词_噪声词, 掩码, 标记目标值 center, context_negative, mask, label = [data.to(device) for data in batch]#调用skip_gram模型预测pred = skip_gram(center, context_negative, embed_v=net[0], embed_u=net[1])#计算损失l = loss(pred.reshape(label.shape).float(), label.float(), mask) / mask.shape[1] * mask.sum(dim=1)#用loss反向传播  ,loss先sum()聚合变成标量(合并成一个数值), 只有标量才能反向传播l.sum().backward()#梯度更新optimizer.step()#累加metric.add(l.sum(), l.numel())   #l.sum()数值求和累加, l.numel()数量累加#   %  取余数      #  //  商向下取整#迭代到总数据量的5%的倍数时 或者 处理到最后一批数据时,执行下面操作#  i+1是因为i是从0开始遍历的if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:  #epoch + (i+1) / num_batches当前迭代次数占整个数据集的比例animator.add(epoch + (i+1) / num_batches, (metric[0] / metric[1]))print(f'loss {metric[0] / metric[1]:.3f}', f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')      
lr, num_epochs = 0.002, 50
train(net, data_iter, lr, num_epochs)

#如果能够找到词的近义词, 就说明训练的不错
def get_similar_tokens(query_token, k, embed):"""query_token:需要预测的词k:最高相似度的词数量embed:embedding层的哪一层"""#获取词向量权重    (词向量权重*词的one_hot编码,就是词向量)W = embed.weight.dataprint(f'W的shape:{W.shape}')x = W[vocab[query_token]]     #embedding层是按照索引查表查词对应的权重-->优点print(f'x的shape:{x.shape}')#计算余弦相似度#torch.mv两个向量的点乘cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9)print(f'cos的shape:{cos.shape}')#排序选择前k个对应的索引topk = torch.topk(cos, k=k+1)[1].cpu().numpy().astype('int32')for i in topk[1:]:   #排除query_token他本身,自己与自己余弦相似度最高print(f'cosine sim={float(cos[i]):.3f}:{vocab.to_tokens(i)}')
get_similar_tokens('food', 3, net[0])

 

W的shape:torch.Size([6719, 100])
x的shape:torch.Size([100])
cos的shape:torch.Size([6719])
cosine sim=0.430:feed
cosine sim=0.418:precious
cosine sim=0.412:drink

2.知识点 

 

相关文章:

word2vector训练代码详解

目录 1.代码实现 2.知识点 1.代码实现 #导包 import math import torch from torch import nn import dltools #加载PTB数据集 &#xff0c;需要把PTB数据集的文件夹放在代码上一级目录的data文件中&#xff0c;不用解压 #批次大小、窗口大小、噪声词大小 batch_size, ma…...

Python的风格应该是怎样的?除语法外,有哪些规范?

写代码不那么pythonic风格的&#xff0c;多多少少都会让人有点难受。 什么是pythonic呢&#xff1f;简而言之&#xff0c;这是一种写代码时遵守的规范&#xff0c;主打简洁、清晰、可读性高&#xff0c;符合PEP 8&#xff08;Python代码样式指南&#xff09;约定的模式。 Pyth…...

net core mvc 数据绑定 《1》

其它的绑定 跟net mvc 一样 》》MVC core 、framework 一样 1 模型绑定数组类型 2 模型绑定集合类型 3 模型绑定复杂的集合类型 4 模型绑定源 》》》》 模型绑定 使用输入数据的原生请求集合是可以工作的【request[],Querystring,request.from[]】&#xff0c; 但是从可读…...

python为姓名注音实战案例

有如下数据&#xff0c;需要对名字注音。 数据样例&#xff1a;&#x1f447; 一、实现过程 前提条件&#xff1a;由于会用到pypinyin库&#xff0c;所以一定得提前安装。 pip install pypinyin1、详细代码&#xff1a; from pypinyin import pinyin, Style# 输入数据 names…...

MATLAB中的艺术:用爱心形状控制坐标轴

在MATLAB中&#xff0c;坐标轴控制是绘图和数据可视化中的一个重要方面。通过精细地管理坐标轴&#xff0c;我们不仅可以改善图形的视觉效果&#xff0c;还可以赋予图形更深的情感寓意。本文将介绍如何在MATLAB中使用坐标轴控制来绘制一个爱心形状&#xff0c;并探讨其背后的技…...

基于mybatis-plus创建springboot,添加增删改查功能,使用postman来测试接口出现的常见错误

1 当你在使用postman检测 添加和更新功能时&#xff0c;报了一个500错误 查看idea发现是&#xff1a; Data truncation: Out of range value for column id at row 1 通过翻译&#xff1a;数据截断&#xff1a;表单第1行的“id”列出现范围外值。一般情况下&#xff0c;出现这个…...

Java:Object操作

目录 1、Object转List对象2、Object转实体对象 1、Object转List对象 List<User> userList MtUtils.ObjectToList(objData, User.class);/*** Object对象转 List集合** param object Object对象* param clazz 需要转换的集合* param <T> 泛型类* return*/ public s…...

Java-并发基础

启动线程的方式 只有&#xff1a; 1、X extends Thread;&#xff0c;然后X.start 2、X implements Runnable&#xff1b;然后交给Thread运行 有争议可以可以查看 Thread源码的注释&#xff1a; There are two ways to create a new thread of execution.Callable的方式需要…...

速盾:网页游戏部署高防服务器有什么优势?

在当前互联网发展的背景下&#xff0c;网页游戏的市场需求不断增长&#xff0c;相应地带来了对高防服务器的需求。高防服务器可以为网页游戏部署提供许多优势&#xff0c;下面就详细介绍一下。 第一&#xff0c;高防服务器具有强大的抗DDoS攻击能力。DDoS攻击是目前互联网上最…...

【从0开始自动驾驶】ros2编写自定义消息 msg文件和msg文件嵌套

【从0开始自动驾驶】ros2编写自定义消息 msg文件和msg文件嵌套 详细解答和讨论请私信在工作空间内新建一个功能包在msg内创建对应的msg文件创建名为TestMsg.msg的文件创建名为TestSubMsg.msg的文件&#xff08;在前一个msg文件中引用&#xff09;修改CmakeList.txt修改package.…...

docker 部署 Seatunnel 和 Seatunnel Web

docker 部署 Seatunnel 和 Seatunnel Web 说明&#xff1a; 部署方式前置条件&#xff0c;已经在宿主机上运行成功运行文件采用挂载宿主机目录的方式部署SeaTunnel Engine 采用的是混合模式集群 编写Dockerfile并打包镜像 Seatunnel FROM openjdk:8 WORKDIR /opt/seatunne…...

【易上手快捷开发新框架技术】nicegui标签组件lable用法庖丁解牛深度解读和示例源代码IDE运行和调试通过截图为证

传奇开心果微博文系列 序言一、标签组件lable最基本用法示例1.在网页上显示出 Hello World 的标签示例2. 使用 style 参数改变标签样式示例 二、标签组件lable更多用法示例1. 添加按钮动态修改标签文字2. 点击按钮动态改变标签内容、颜色、大小和粗细示例代码3. 添加开关组件动…...

从HarmonyOS Next导出手机照片

1&#xff09;打开DevEco Studio开发工具 2&#xff09;插入USB数据线&#xff0c;连接手机 3&#xff09;在DevEco Studio开发工具&#xff0c;通过View -> Tool Windows -> Device File Browser打开管理工具 4&#xff09;选择storage -> cloud -> 100->fi…...

[Docker学习笔记]Docker的原理Docker常见命令

文章目录 什么是DockerDocker的优势Docker的原理Docker 的安装Docker 的 namespaces Docker的常见命令docker version:查看版本信息docker info 查看docker详细信息我们关注的信息 docker search:镜像搜索docker pull:镜像拉取到本地docker push:推送本地镜像到镜像仓库docker …...

【ESP 保姆级教程】小课设篇 —— 案例:20240507_esp01s+UNO的智能浇水系统

忘记过去,超越自己 ❤️ 博客主页 单片机菜鸟哥,一个野生非专业硬件IOT爱好者 ❤️❤️ 本篇创建记录 2024-09-30 ❤️❤️ 本篇更新记录 2023-09-30 ❤️🎉 欢迎关注 🔎点赞 👍收藏 ⭐️留言📝🙏 此博客均由博主单独编写,不存在任何商业团队运营,如发现错误,请…...

如何设置MySQL分布式架构主键ID,为什么不能使用自增ID或者UUID做主键?

MySQL分布式架构主键ID的设置方法 雪花算法&#xff08;Snowflake&#xff09; 原理&#xff1a;雪花算法是一种生成分布式唯一ID的算法。它由64位二进制数组成&#xff0c;结构如下&#xff1a;1位符号位&#xff08;固定为0&#xff09; 41位时间戳&#xff08;表示从一个固…...

服务器虚拟化详解

服务器虚拟化详解 服务器虚拟化是一种将物理服务器资源转化为虚拟服务器资源的技术&#xff0c;它允许在一台物理服务器上运行多个虚拟服务器&#xff0c;每个虚拟服务器都拥有独立的操作系统、应用程序和资源配置。这种技术极大地提高了服务器的利用率、灵活性和可扩展性&…...

医疗陪诊APP开发实战:从互联网医院系统源码开始

本文将从互联网医院系统源码出发&#xff0c;深入探讨医疗陪诊APP的开发实战。 一、从互联网医院系统源码入手 开发医疗陪诊APP的基础在于互联网医院系统的源码。互联网医院系统通常包括以下几个模块&#xff1a; 1.用户管理&#xff1a;用户注册、登录、信息管理等功能。 …...

jenkins 构建报错ERROR: Error fetching remote repo ‘origin‘

问题描述 修改项目的仓库地址后&#xff0c;使用jenkins构建报错 Running as SYSTEM Building in workspace /var/jenkins_home/workspace/【测试】客户端/client-fonchain-main The recommended git tool is: NONE using credential 680a5841-cfa5-4d8a-bb38-977f796c26dd&g…...

初识C#(三)- 数组

我有17栋楼&#xff0c;在不同地域&#xff0c;都是不同价格租出去给不同的人~ 文章目录 前言一、数组1.1 我有17栋楼 - 数组的声明1.2 包租公&包租婆 - 数组赋值1.3 每个月都要交租的苦逼租客 - 数组的使用 二、字符串2.1 字符串的使用方法 总结 前言 本篇笔记重点描述C#…...

【Axure高保真原型】引导弹窗

今天和大家中分享引导弹窗的原型模板&#xff0c;载入页面后&#xff0c;会显示引导弹窗&#xff0c;适用于引导用户使用页面&#xff0c;点击完成后&#xff0c;会显示下一个引导弹窗&#xff0c;直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...

Linux简单的操作

ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...

中医有效性探讨

文章目录 西医是如何发展到以生物化学为药理基础的现代医学&#xff1f;传统医学奠基期&#xff08;远古 - 17 世纪&#xff09;近代医学转型期&#xff08;17 世纪 - 19 世纪末&#xff09;​现代医学成熟期&#xff08;20世纪至今&#xff09; 中医的源远流长和一脉相承远古至…...

Python基于历史模拟方法实现投资组合风险管理的VaR与ES模型项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档&#xff09;&#xff0c;如需数据代码文档可以直接到文章最后关注获取。 1.项目背景 在金融市场日益复杂和波动加剧的背景下&#xff0c;风险管理成为金融机构和个人投资者关注的核心议题之一。VaR&…...

Kubernetes 节点自动伸缩(Cluster Autoscaler)原理与实践

在 Kubernetes 集群中&#xff0c;如何在保障应用高可用的同时有效地管理资源&#xff0c;一直是运维人员和开发者关注的重点。随着微服务架构的普及&#xff0c;集群内各个服务的负载波动日趋明显&#xff0c;传统的手动扩缩容方式已无法满足实时性和弹性需求。 Cluster Auto…...

规则与人性的天平——由高考迟到事件引发的思考

当那位身着校服的考生在考场关闭1分钟后狂奔而至&#xff0c;他涨红的脸上写满绝望。铁门内秒针划过的弧度&#xff0c;成为改变人生的残酷抛物线。家长声嘶力竭的哀求与考务人员机械的"这是规定"&#xff0c;构成当代中国教育最尖锐的隐喻。 一、刚性规则的必要性 …...

2025.6.9总结(利与弊)

凡事都有两面性。在大厂上班也不例外。今天找开发定位问题&#xff0c;从一个接口人不断溯源到另一个 接口人。有时候&#xff0c;不知道是谁的责任填。将工作内容分的很细&#xff0c;每个人负责其中的一小块。我清楚的意识到&#xff0c;自己就是个可以随时替换的螺丝钉&…...

Redis上篇--知识点总结

Redis上篇–解析 本文大部分知识整理自网上&#xff0c;在正文结束后都会附上参考地址。如果想要深入或者详细学习可以通过文末链接跳转学习。 1. 基本介绍 Redis 是一个开源的、高性能的 内存键值数据库&#xff0c;Redis 的键值对中的 key 就是字符串对象&#xff0c;而 val…...

shell脚本质数判断

shell脚本质数判断 shell输入一个正整数,判断是否为质数(素数&#xff09;shell求1-100内的质数shell求给定数组输出其中的质数 shell输入一个正整数,判断是否为质数(素数&#xff09; 思路&#xff1a; 1:1 2:1 2 3:1 2 3 4:1 2 3 4 5:1 2 3 4 5-------> 3:2 4:2 3 5:2 3…...

npm安装electron下载太慢,导致报错

npm安装electron下载太慢&#xff0c;导致报错 背景 想学习electron框架做个桌面应用&#xff0c;卡在了安装依赖&#xff08;无语了&#xff09;。。。一开始以为node版本或者npm版本太低问题&#xff0c;调整版本后还是报错。偶尔执行install命令后&#xff0c;可以开始下载…...