当前位置: 首页 > news >正文

pytorch实现基于Word2Vec的词嵌入

PyTorch 实现 Word2Vec(Skip-gram 模型) 的完整代码,使用 中文语料 进行训练,包括数据预处理、模型定义、训练和测试


1. 主要特点

支持中文数据,基于 jieba 进行分词
使用 Skip-gram 进行训练,适用于小数据集
支持负采样,提升训练效率
使用 cosine similarity 计算相似单词

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import jieba
from collections import Counter
from sklearn.metrics.pairwise import cosine_similarity# ========== 1. 数据预处理 ==========
corpus = ["我们 喜欢 深度 学习","自然 语言 处理 是 有趣 的","人工智能 改变 了 世界","深度 学习 是 人工智能 的 重要 组成部分"
]# 超参数
window_size = 2      # 窗口大小
embedding_dim = 10   # 词向量维度
num_epochs = 100     # 训练轮数
learning_rate = 0.01 # 学习率
batch_size = 4       # 批大小
neg_samples = 5      # 负采样个数# 分词 & 构建词汇表
tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]
vocab = set(word for sentence in tokenized_corpus for word in sentence)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}# 统计词频
word_counts = Counter([word for sentence in tokenized_corpus for word in sentence])
total_words = sum(word_counts.values())# 计算负采样概率
word_freqs = {word: count / total_words for word, count in word_counts.items()}
word_powers = {word: freq ** 0.75 for word, freq in word_freqs.items()}
Z = sum(word_powers.values())
word_distribution = {word: prob / Z for word, prob in word_powers.items()}# 负采样函数
def negative_sampling(positive_word, num_samples=5):words = list(word_distribution.keys())probabilities = list(word_distribution.values())negatives = []while len(negatives) < num_samples:neg = np.random.choice(words, p=probabilities)if neg != positive_word:negatives.append(neg)return negatives# 生成 Skip-gram 训练数据
data = []
for sentence in tokenized_corpus:indices = [word2idx[word] for word in sentence]for center_idx in range(len(indices)):center_word = indices[center_idx]for offset in range(-window_size, window_size + 1):context_idx = center_idx + offsetif 0 <= context_idx < len(indices) and context_idx != center_idx:context_word = indices[context_idx]data.append((center_word, context_word))# 转换为 PyTorch 张量
data = [(torch.tensor(center), torch.tensor(context)) for center, context in data]# ========== 2. 定义 Word2Vec (Skip-gram) 模型 ==========
class Word2Vec(nn.Module):def __init__(self, vocab_size, embedding_dim):super(Word2Vec, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)self.output_layer = nn.Linear(embedding_dim, vocab_size)def forward(self, center_word):embed = self.embedding(center_word)  # 获取中心词向量out = self.output_layer(embed)       # 计算词分布return out# 初始化模型
model = Word2Vec(len(vocab), embedding_dim)# ========== 3. 训练 Word2Vec ==========
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)for epoch in range(num_epochs):total_loss = 0random.shuffle(data)  # 每轮打乱数据for center_word, context_word in data:optimizer.zero_grad()output = model(center_word.unsqueeze(0))  # 预测词分布loss = criterion(output, context_word.unsqueeze(0))  # 计算损失loss.backward()optimizer.step()total_loss += loss.item()if (epoch + 1) % 10 == 0:print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}")# ========== 4. 测试词向量 ==========
word_vectors = model.embedding.weight.data.numpy()# 计算单词相似度
def most_similar(word, top_n=3):if word not in word2idx:return "单词不在词汇表中"word_vec = word_vectors[word2idx[word]].reshape(1, -1)similarities = cosine_similarity(word_vec, word_vectors)[0]# 获取相似度最高的 top_n 个单词(排除自身)similar_idx = similarities.argsort()[::-1][1:top_n+1]return [(idx2word[idx], similarities[idx]) for idx in similar_idx]# 测试相似词
test_words = ["深度", "学习", "人工智能"]
for word in test_words:print(f"【{word}】的相似单词:", most_similar(word))

数据预处理
  • 使用 jieba.cut() 进行分词
  • 创建 word2idxidx2word
  • 使用滑动窗口生成 (中心词, 上下文词) 训练样本
  • 实现 negative_sampling() 提高训练效率
模型
  • Embedding 学习词向量
  • Linear 计算单词的概率分布
  • CrossEntropyLoss 计算目标词与预测词的匹配度
  • 使用 Adam 进行梯度更新
计算词相似度
  • 使用 cosine_similarity 计算词向量相似度
  • 找出 top_n 个最相似的单词

 5. 可优化点

 使用更大的中文语料库(如 THUCNews
 使用 t-SNE 进行词向量可视化
增加负采样,提升模型训练效率

相关文章:

pytorch实现基于Word2Vec的词嵌入

PyTorch 实现 Word2Vec&#xff08;Skip-gram 模型&#xff09; 的完整代码&#xff0c;使用 中文语料 进行训练&#xff0c;包括数据预处理、模型定义、训练和测试。 1. 主要特点 支持中文数据&#xff0c;基于 jieba 进行分词 使用 Skip-gram 进行训练&#xff0c;适用于小数…...

流媒体娱乐服务平台在AWS上使用Presto作为大数据的交互式查询引擎的具体流程和代码

一家流媒体娱乐服务平台拥有庞大的用户群体和海量的数据。为了高效处理和分析这些数据&#xff0c;它选择了Presto作为其在AWS EMR上的大数据查询引擎。在AWS EMR上使用Presto取得了显著的成果和收获。这些成果不仅提升了数据查询效率&#xff0c;降低了运维成本&#xff0c;还…...

鸿蒙 循环控制 简单用法

效果 简单使用如下&#xff1a; class Item {id: numbername: stringprice: numberimg: stringdiscount: numberconstructor(id: number, name: string, price: number, img: string, discount: number) {this.id idthis.name namethis.price pricethis.img imgthis.discou…...

四、GPIO中断实现按键功能

4.1 GPIO简介 输入输出&#xff08;I/O&#xff09;是一个非常重要的概念。I/O泛指所有类型的输入输出端口&#xff0c;包括单向的端口如逻辑门电路的输入输出管脚和双向的GPIO端口。而GPIO&#xff08;General-Purpose Input/Output&#xff09;则是一个常见的术语&#xff0c…...

Linux安装zookeeper

1, 下载 Apache ZooKeeperhttps://zookeeper.apache.org/releases.htmlhttps://zookeeper.apache.org/releases.htmlhttps://zookeeper.apache.org/releases.htmlhttps://zookeeper.apache.org/releases.htmlhttps://zookeeper.apache.org/releases.htmlhttps://zookeeper.apa…...

【贪心算法篇】:“贪心”之旅--算法练习题中的智慧与策略(二)

✨感谢您阅读本篇文章&#xff0c;文章内容是个人学习笔记的整理&#xff0c;如果哪里有误的话还请您指正噢✨ ✨ 个人主页&#xff1a;余辉zmh–CSDN博客 ✨ 文章所属专栏&#xff1a;贪心算法篇–CSDN博客 文章目录 前言例题1.买卖股票的最佳时机2.买卖股票的最佳时机23.k次取…...

007 JSON Web Token

文章目录 https://doc.hutool.cn/pages/jwt/#jwt%E4%BB%8B%E7%BB%8D JWT是一种用于双方之间安全传输信息的简洁的、URL安全的令牌标准。这个标准由互联网工程任务组(IETF)发表&#xff0c;定义了一种紧凑且自包含的方式&#xff0c;用于在各方之间作为JSON对象安全地传输信息。…...

Windsurf cursor vscode+cline 与Python快速开发指南

Windsurf简介 Windsurf是由Codeium推出的全球首个基于AI Flow范式的智能IDE&#xff0c;它通过强大的AI助手功能&#xff0c;显著提升开发效率。Windsurf集成了先进的代码补全、智能重构、代码生成等功能&#xff0c;特别适合Python开发者使用。 Python环境配置 1. Conda安装…...

将markdown文件和LaTex公式转为word

通义千问等大模型生成的回答多数是markdown类型的&#xff0c;需要将他们转为Word文件 一 pypandoc 介绍 1. 项目介绍 pypandoc 是一个用于 pandoc 的轻量级 Python 包装器。pandoc 是一个通用的文档转换工具&#xff0c;支持多种格式的文档转换&#xff0c;如 Markdown、HTM…...

grpc 和 http 的区别---二进制vsJSON编码

gRPC 和 HTTP 是两种广泛使用的通信协议&#xff0c;各自适用于不同的场景。以下是它们的详细对比与优势分析&#xff1a; 一、核心特性对比 特性gRPCHTTP协议基础基于 HTTP/2基于 HTTP/1.1 或 HTTP/2数据格式默认使用 Protobuf&#xff08;二进制&#xff09;通常使用 JSON/…...

C#面向对象(封装)

1.什么是封装? C# 封装 封装 被定义为“把一个或多个项目封闭在一个物理的或者逻辑的包中”。 在面向对象程序设计方法论中&#xff0c;封装是为了防止对实现细节的访问。 抽象和封装是面向对象程序设计的相关特性。 抽象允许相关信息可视化&#xff0c;封装则使开发者实现所…...

kamailio-kamctl monitor解释

这段输出是 Kamailio 服务器的运行时信息和统计数据的摘要。以下是对每个部分的详细解释&#xff1a; 1. Kamailio Runtime Details cycle #: 3: 表示 Kamailio 的主循环已经运行了 3 个周期。Kamailio 是一个事件驱动的服务器&#xff0c;主循环用于处理事件和请求。if const…...

39. I2C实验

一、IIC协议详解 1、ALPHA开发板上有个AP3216C&#xff0c;这是一个IIC接口的器件&#xff0c;这是一个环境光传感器。AP3216C连接到了I2C1上: I2C1_SCL: 使用的是UART4_TXD这个IO&#xff0c;复用位ALT2 I2C1_SDA: 使用的是UART4_RXD这个IO。复用为ALT2 2、I2C分为SCL和SDA&…...

GPIO配置通用输出,推挽输出,开漏输出的作用,以及输出上下拉起到的作用

通用输出说明&#xff1a; ①输出原理&#xff1a; 对输出数据寄存器的对应位写0 或 1&#xff0c;就可以控制对应编号的IO口输出低/高电平 ②输出类型 推挽输出&#xff1a;IO口可以输出高电平&#xff0c;也可以输出低电平 开漏输出&#xff1a;IO口只能输出低电平 所以…...

Spring AOP 入门教程:基础概念与实现

目录 第一章&#xff1a;AOP概念的引入 第二章&#xff1a;AOP相关的概念 1. AOP概述 2. AOP的优势 3. AOP的底层原理 第三章&#xff1a;Spring的AOP技术 - 配置文件方式 1. AOP相关的术语 2. AOP配置文件方式入门 3. 切入点的表达式 4. AOP的通知类型 第四章&#x…...

DeepSeek 核心技术全景解析

DeepSeek 核心技术全景解析&#xff1a;突破性创新背后的设计哲学 DeepSeek的创新不仅仅是对AI基础架构的改进&#xff0c;更是一场范式革命。本文将深入剖析其核心技术&#xff0c;探讨 如何突破 Transformer 计算瓶颈、如何在 MoE&#xff08;Mixture of Experts&#xff09…...

90,【6】攻防世界 WEB Web_php_unserialize

进入靶场 进入靶场 <?php // 定义一个名为 Demo 的类 class Demo { // 定义一个私有属性 $file&#xff0c;默认值为 index.phpprivate $file index.php;// 构造函数&#xff0c;当创建类的实例时会自动调用// 接收一个参数 $file&#xff0c;用于初始化对象的 $file 属…...

实现网站内容快速被搜索引擎收录的方法

本文转自&#xff1a;百万收录网 原文链接&#xff1a;https://www.baiwanshoulu.com/6.html 实现网站内容快速被搜索引擎收录&#xff0c;是网站运营和推广的重要目标之一。以下是一些有效的方法&#xff0c;可以帮助网站内容更快地被搜索引擎发现和收录&#xff1a; 一、确…...

WSL2中安装的ubuntu搭建tftp服务器uboot通过tftp下载

Windows中安装wsl2&#xff0c;wsl2里安装ubuntu。 1. Wsl启动后 1&#xff09;Windows下ip ipconfig 以太网适配器 vEthernet (WSL (Hyper-V firewall)): 连接特定的 DNS 后缀 . . . . . . . : IPv4 地址 . . . . . . . . . . . . : 172.19.32.1 子网掩码 . . . . . . . .…...

机器学习优化算法:从梯度下降到Adam及其变种

机器学习优化算法&#xff1a;从梯度下降到Adam及其变种 引言 最近deepseek的爆火已然说明&#xff0c;在机器学习领域&#xff0c;优化算法是模型训练的核心驱动力。无论是简单的线性回归还是复杂的深度神经网络&#xff0c;优化算法的选择直接影响模型的收敛速度、泛化性能…...

终极指南:如何安全高效地使用APKMirror下载安卓应用

终极指南&#xff1a;如何安全高效地使用APKMirror下载安卓应用 【免费下载链接】APKMirror 项目地址: https://gitcode.com/gh_mirrors/ap/APKMirror APKMirror是一款专注于安卓应用安全下载与管理的开源工具&#xff0c;为你提供官方应用商店之外的可靠替代方案。通过…...

暗黑3终极按键助手D3KeyHelper:图形化配置解放你的双手

暗黑3终极按键助手D3KeyHelper&#xff1a;图形化配置解放你的双手 【免费下载链接】D3keyHelper D3KeyHelper是一个有图形界面&#xff0c;可自定义配置的暗黑3鼠标宏工具。 项目地址: https://gitcode.com/gh_mirrors/d3/D3keyHelper 还在为暗黑破坏神3中繁琐的技能按…...

国产LDO CN86L028实战:解决图像传感器电源噪声,兼容BL8062

1. 项目概述与核心需求解析最近在折腾一个老式录像机的修复与升级项目&#xff0c;目标很明确&#xff1a;提升其图像采集的稳定性。这台设备在运行中&#xff0c;画面时不时会出现条纹干扰&#xff0c;声音里也夹杂着微弱的底噪&#xff0c;尤其是在电源波动较大的环境下&…...

如何轻松掌握res-downloader:高效下载网络资源的终极指南

如何轻松掌握res-downloader&#xff1a;高效下载网络资源的终极指南 【免费下载链接】res-downloader 视频号、小程序、抖音、快手、小红书、直播流、m3u8、酷狗、QQ音乐等常见网络资源下载! 项目地址: https://gitcode.com/GitHub_Trending/re/res-downloader 你是否曾…...

Midscene.js技术架构深度解析:构建企业级视觉驱动自动化测试平台的技术挑战与解决方案

Midscene.js技术架构深度解析&#xff1a;构建企业级视觉驱动自动化测试平台的技术挑战与解决方案 【免费下载链接】midscene AI-powered, vision-driven UI automation for every platform. 项目地址: https://gitcode.com/GitHub_Trending/mid/midscene 在当今多平台、…...

城市规划师实战:如何用TransCad+四阶段法,为你的新区规划提供交通量支撑?

城市规划师实战&#xff1a;TransCad与四阶段法在新区交通规划中的深度应用 1. 从理论到实践&#xff1a;四阶段法的核心逻辑 在Z新城规划项目中&#xff0c;我们面临的核心挑战是如何科学预测未来15年的交通需求。四阶段法作为交通规划领域的经典方法论&#xff0c;其价值在于…...

Crustocean/conch:云原生容器化应用构建与部署的自动化工具箱

1. 项目概述与核心价值最近在折腾一个很有意思的项目&#xff0c;叫“Crustocean/conch”。光看这个名字&#xff0c;你可能觉得有点摸不着头脑&#xff0c;又是“甲壳海洋”又是“海螺”的。其实&#xff0c;这是一个非常典型的、由开发者社区驱动的开源项目命名风格&#xff…...

Midjourney V6树胶重铬酸盐输出崩溃?紧急修复指南(含--sref自定义光敏响应曲线参数实测数据)

更多请点击&#xff1a; https://intelliparadigm.com 第一章&#xff1a;Midjourney V6树胶重铬酸盐输出崩溃现象与本质溯源 现象复现与触发条件 Midjourney V6 在启用 --style raw 且 prompt 中包含化学术语&#xff08;如“重铬酸盐”、“树胶”、“potassium dichromate”…...

告别硬盘数据丢失焦虑!电脑专属5种恢复方法,无踩坑,速存

日常使用电脑时&#xff0c;文件误删是高频突发状况——辛苦整理的办公文档、珍藏的生活影像、重要的程序安装包&#xff0c;一旦不小心删除&#xff0c;难免让人手足无措。好在2026年&#xff0c;随着数据存储技术的迭代与恢复工具的升级&#xff0c;电脑误删文件的恢复成功率…...

5分钟掌握百度网盘高速下载神器:完全免费的开源解析工具终极指南

5分钟掌握百度网盘高速下载神器&#xff1a;完全免费的开源解析工具终极指南 【免费下载链接】baidu-wangpan-parse 获取百度网盘分享文件的下载地址 项目地址: https://gitcode.com/gh_mirrors/ba/baidu-wangpan-parse 还在为百度网盘非会员下载速度只有几十KB而烦恼吗…...