当前位置: 首页 > news >正文

论文辅助笔记:t2vec models.py

1 EncoderDecoder

1.1 _init_

class EncoderDecoder(nn.Module):def __init__(self, vocab_size, embedding_size,hidden_size, num_layers, dropout, bidirectional):super(EncoderDecoder, self).__init__()self.vocab_size = vocab_size #词汇表大小self.embedding_size = embedding_size #词向量嵌入的维度大小## the embedding shared by encoder and decoderself.embedding = nn.Embedding(vocab_size, embedding_size,padding_idx=constants.PAD)#词向量嵌入层self.encoder = Encoder(embedding_size, hidden_size, num_layers,dropout, bidirectional, self.embedding)#编码器self.decoder = Decoder(embedding_size, hidden_size, num_layers,dropout, self.embedding)#解码器self.num_layers = num_layers

1.2 load_pretrained_embedding

从指定的路径加载预训练的词嵌入权重,并将这些权重复制到模型中的 embedding

def load_pretrained_embedding(path):if os.path.isfile(path):w = torch.load(path)#加载预训练的嵌入权重到变量 wself.embedding.weight.data.copy_(w)#将加载的权重 w 复制到模型的嵌入层

1.3 encoder_hn2decoder_h0

'''
转换编码器的输出隐藏状态
'''
def encoder_hn2decoder_h0(self, h):"""Input:编码器的输出隐藏状态h (num_layers * num_directions, batch, hidden_size): encoder output hn---Output: 解码器的初始隐藏状态h (num_layers, batch, hidden_size * num_directions): decoder input h0"""if self.encoder.num_directions == 2:num_layers, batch, hidden_size = h.size(0)//2, h.size(1), h.size(2)#根据输入 h 的形状计算 num_layers, batch 和 hidden_sizereturn h.view(num_layers, 2, batch, hidden_size)\.transpose(1, 2).contiguous()\.view(num_layers, batch, hidden_size * 2)'''使用 view 方法将 h 重塑为形状 (num_layers, 2, batch, hidden_size)。这里的 2 对应于双向RNN的两个方向使用 transpose 交换第2和第3维使用 contiguous 确保张量在内存中是连续的使用 view 方法再次重塑张量,将两个方向的隐藏状态连接在一起,形成形状 (num_layers, batch, hidden_size * 2) 的张量'''else:return h

pytorch笔记:contiguous &tensor 存储知识_pytorch中的tensor存储是列主布局还是行主布局_UQI-LIUWJ的博客-CSDN博客 

1.4 forward

def forward(self, src, lengths, trg):"""Input:src (src_seq_len, batch): source tensor 源序列lengths (1, batch): source sequence lengths 源序列的长度trg (trg_seq_len, batch): target tensor, the `seq_len` in trg is notnecessarily the same as that in src 目标序列需要注意的是,目标序列的长度并不一定与源序列的长度相同---Output:output (trg_seq_len, batch, hidden_size)"""encoder_hn, H = self.encoder(src, lengths)#将源序列src和其长度lengths传递给编码器decoder_h0 = self.encoder_hn2decoder_h0(encoder_hn)#将编码器的输出隐藏状态encoder_hn转换为适合解码器的初始隐藏状态decoder_h0。## for target we feed the range [BOS:EOS-1] into decoderoutput, decoder_hn = self.decoder(trg[:-1], decoder_h0, H)return output

2 Encoder

2.1 init


class Encoder(nn.Module):def __init__(self, input_size, hidden_size, num_layers, dropout,bidirectional, embedding):"""embedding (vocab_size, input_size): pretrained embedding"""super(Encoder, self).__init__()self.num_directions = 2 if bidirectional else 1#根据 bidirectional 参数决定方向数量assert hidden_size % self.num_directions == 0self.hidden_size = hidden_size // self.num_directionsself.num_layers = num_layersself.embedding = embeddingself.rnn = nn.GRU(input_size, self.hidden_size,num_layers=num_layers,bidirectional=bidirectional,dropout=dropout)

2.2 forward

'''
数据在编码器中的传播方式,并且考虑了序列的真实长度以处理填充
'''
def forward(self, input, lengths, h0=None):"""Input:input (seq_len, batch): padded sequence tensorlengths (1, batch): sequence lengthsh0 (num_layers*num_directions, batch, hidden_size): initial hidden state---Output:hn (num_layers*num_directions, batch, hidden_size):the hidden state of each layeroutput (seq_len, batch, hidden_size*num_directions): output tensor"""# (seq_len, batch) => (seq_len, batch, input_size)embed = self.embedding(input)#将输入序列索引转换为嵌入表示#input(seq_len,batch)->embed(seq_len,batch,self.embedding_size)lengths = lengths.data.view(-1).tolist()if lengths is not None:embed = pack_padded_sequence(embed, lengths)#使用pack_padded_sequence对填充的序列进行打包,以便RNN可以跳过填充项output, hn = self.rnn(embed, h0)#将嵌入的序列传递给GRU RNNif lengths is not None:output = pad_packed_sequence(output)[0]#使用pad_packed_sequence对输出序列进行解包,得到RNN的完整输出return hn, output

pytorch 笔记:PAD_PACKED_SEQUENCE 和PACK_PADDED_SEQUENCE-CSDN博客

pytorch笔记:PackedSequence对象送入RNN-CSDN博客

3  Decoder

3.1 init

def __init__(self, input_size, hidden_size, num_layers, dropout, embedding):super(Decoder, self).__init__()self.embedding = embeddingself.rnn = StackingGRUCell(input_size, hidden_size, num_layers,dropout)self.attention = GlobalAttention(hidden_size)self.dropout = nn.Dropout(dropout)self.num_layers = num_layers

 3.2 forward

'''
seq2seq的解码过程,使用了可选的注意力机制
'''
def forward(self, input, h, H, use_attention=True):"""Input:input (seq_len, batch): padded sequence tensorh (num_layers, batch, hidden_size): input hidden stateH (seq_len, batch, hidden_size): the context used in attention mechanismwhich is the output of encoderuse_attention: If True then we use attention---Output:output (seq_len, batch, hidden_size)h (num_layers, batch, hidden_size): output hidden state,h may serve as input hidden state for the next iteration,especially when we feed the word one by one (i.e., seq_len=1)such as in translation"""assert input.dim() == 2, "The input should be of (seq_len, batch)"# (seq_len, batch) => (seq_len, batch, input_size)embed = self.embedding(input)#将输入序列转换为嵌入向量output = []# split along the sequence length dimensionfor e in embed.split(1):#split(1)每次沿着seq_len方法分割一行#即每个e的维度是(1,batch,input_size)e = e.squeeze(0) # (1, batch, input_size) => (batch, input_size)o, h = self.rnn(e, h)#用RNN处理嵌入向量,并得到输出o和新的隐藏状态h#这边的RNN是StackingGRUCell,也即我认为可能是seq_len为1的GRU#o:(batch, hidden_size)#h:(num_layers,batch, hidden_size)if use_attention:o = self.attention(o, H.transpose(0, 1))#如果use_attention为True,将使用注意力机制处理RNN的输出o = self.dropout(o)#为了正则化和防止过拟合,应用 dropoutoutput.append(o)output = torch.stack(output)#将所有的输出叠加为一个张量return output, h#(seq_len, batch, hidden_size)

4 StackingGRUCell

个人感觉就是

class StackingGRUCell(nn.Module):"""Multi-layer CRU Cell"""def __init__(self, input_size, hidden_size, num_layers, dropout):super(StackingGRUCell, self).__init__()self.num_layers = num_layersself.grus = nn.ModuleList()self.dropout = nn.Dropout(dropout)self.grus.append(nn.GRUCell(input_size, hidden_size))for i in range(1, num_layers):self.grus.append(nn.GRUCell(hidden_size, hidden_size))
def forward(self, input, h0):"""Input:input (batch, input_size): input tensorh0 (num_layers, batch, hidden_size): initial hidden state---Output:output (batch, hidden_size): the final layer output tensorhn (num_layers, batch, hidden_size): the hidden state of each layer"""hn = []output = inputfor i, gru in enumerate(self.grus):hn_i = gru(output, h0[i])#在每一次循环中,输入output会经过一个GRU单元并更新隐藏状态hn.append(hn_i)if i != self.num_layers - 1:output = self.dropout(hn_i)else:output = hn_i#如果不是最后一层,输出会经过一个dropout层。hn = torch.stack(hn)#将hn列表转变为一个张量return output, hn

5 GlobalAttention

'''
对于给定的查询向量q,查找上下文矩阵H中哪些向量与其最相关,并使用这些相关性的加权和来生成一个新的上下文向量
'''
class GlobalAttention(nn.Module):"""$$a = \sigma((W_1 q)H)$$$$c = \tanh(W_2 [a H, q])$$"""def __init__(self, hidden_size):super(GlobalAttention, self).__init__()self.L1 = nn.Linear(hidden_size, hidden_size, bias=False)self.L2 = nn.Linear(2*hidden_size, hidden_size, bias=False)self.softmax = nn.Softmax(dim=1)self.tanh = nn.Tanh()def forward(self, q, H):"""Input:q (batch, hidden_size): queryH (batch, seq_len, hidden_size): context---Output:c (batch, hidden_size)"""# (batch, hidden_size) => (batch, hidden_size, 1)q1 = self.L1(q).unsqueeze(2)#使用线性变换L1对查询向量q进行变换,然后增加一个维度以进行后续的批量矩阵乘法# (batch, seq_len)a = torch.bmm(H, q1).squeeze(2)#计算查询向量与上下文矩阵H中的每一个向量的点积。#这将生成一个形状为(batch, seq_len)的张量,表示查询向量与每个上下文向量的相似度a = self.softmax(a)#经过softmax,得到注意力权重# (batch, seq_len) => (batch, 1, seq_len)a = a.unsqueeze(1)#增加一个维度以进行后续的批量矩阵乘法# (batch, hidden_size)c = torch.bmm(a, H).squeeze(1)#使用注意力权重与上下文矩阵H进行加权求和,得到上下文向量c# (batch, hidden_size * 2)c = torch.cat([c, q], 1)#将上下文向量与查询向量连接在一起return self.tanh(self.L2(c))#使用线性变换L2对连接后的向量进行变换,并使用tanh激活函数

相关文章:

论文辅助笔记:t2vec models.py

1 EncoderDecoder 1.1 _init_ class EncoderDecoder(nn.Module):def __init__(self, vocab_size, embedding_size,hidden_size, num_layers, dropout, bidirectional):super(EncoderDecoder, self).__init__()self.vocab_size vocab_size #词汇表大小self.embedding_size e…...

R语言如何写一个爬虫代码模版

R语言爬虫是利用R语言中的网络爬虫包,如XML、RCurl、rvest等,批量自动将网页的内容抓取下来。在进行R语言爬虫之前,需要了解HTML、XML、JSON等网页语言,因为正是通过这些语言我们才能在网页中提取数据。 在爬虫过程中,…...

鸿运主动安全云平台任意文件下载漏洞复习

简介 深圳市强鸿电子有限公司鸿运主动安全监控云平台网页存在任意文件下载漏洞,攻击者可通过此漏洞下载网站配置文件等获得登录账号密码 漏洞复现 FOFA语法:body"./open/webApi.html" 获取网站数据库配置文件 POC:/808gps/Mobile…...

CMake基础【学习笔记(八)】

声明此博客为转载 CMake基础 文章目录 CMake基础一、准备知识1.1 C的编译过程1.2 静态链接库和动态链接库1.3 为什么需要CMake1.3.1 g 命令行编译1.3.2 CMake简介 二、CMake基础知识2.1 安装2.2 第一个CMake例子2.3 语法基础2.3.1 指定版本2.3.2 设置项目2.3.3 添加可执行文件…...

异常的学习

异常分为编译时期异常与运行时期异常 编译时期异常运行前必须处理,否则代码报错 除了RuntimeException和他的子类,其他都是编译时异常 运行时期异常运行时报错,一般是由参数传递错误导致的报错 异常的作用: 1.异常使用来查询b…...

【洛谷 P1101】单词方阵 题解(深度优先搜索)

单词方阵 题目描述 给一 n n n \times n nn 的字母方阵,内可能蕴含多个 yizhong 单词。单词在方阵中是沿着同一方向连续摆放的。摆放可沿着 8 8 8 个方向的任一方向,同一单词摆放时不再改变方向,单词与单词之间可以交叉,因此…...

教师减负神器

在传统的成绩管理模式中,教师需要手动输入、整理、分析成绩数据,工作量大且繁琐。这不仅耗费了教师大量的时间和精力,还容易出现错误。为了解决这个问题,我们可以通过各种代码和Excel来实现学生自助查询成绩的功能。 一、建立成绩…...

Web 开发之前的一些话

我主要是对单页面进行开发,因而VUEFlask的搭配足以满足我的需求; VUE Vue.js - 渐进式 JavaScript 框架 | Vue.js Element-UI Element - The worlds most popular Vue UI framework FLASK 欢迎来到 Flask 的世界 — Flask中文文档(2.3.x)...

git快速入门!!! git的常用命令!!!

git快速入门 git的常用命令1. 初始化一个新的 Git 仓库2. 添加文件到暂存区3. 提交更改4. 查看当前分支的状态5. 创建并切换到新的分支6. 切换回之前的分支7. 合并分支8. 拉取远程仓库的更新9. 推送本地仓库的更新 git remote -v是什么git fetchclone命令详解push指定的分支git…...

C++并发编程实战——01.并发与并行

文章目录 并发并行及其使用原因并发与并行使用与不使用并发的原因C多线程支持 并发并行及其使用原因 本书相关 github翻译地址本书源码下载地址第一版github 翻译地址英文原版PDF不错的笔记所有实例的源代码,可在出版商的网站上进行下载github上下载源码 路线图 …...

PLC如何远程控制、调试?贝锐蒲公英二层组网功能一招搞定

在制造、交通、能源、采矿等领域,工业物联网是热门话题,各类采集、控制器、控制传感器通过网络互联,实现信息实时共享、交互后,不仅能快速了解生产过程数据,还能用于设备远程、调试维护等场景,对优化生产过…...

【大数据】-- flink kubernetes operator 入门与实践

课程链接:https://edu.csdn.net/course/detail/38831 目录 课程链接:https://edu.csdn.net/course/detail/38831https://edu.csdn.net/course/detail/38831 一、你将收获...

网络安全在代理技术中的实现与应用

随着互联网技术的飞速发展,网络安全日益受到人们的重视。在这个背景下,代理技术成为了网络安全实现的重要手段之一。本文将针对 SOCKS5 代理、SK5 代理、IP 代理等代理技术,探讨它们在网络安全和爬虫应用中的重要性,并介绍 HTTP 协…...

Nginx搭配负载均衡和动静分离:构建高性能Web应用的完美组合

目录 前言 一、Nginx简介 1.Nginx是什么 2.Nginx的特点 3.Nginx在哪使用 4.如何使用Nginx 5.Nginx的优缺点 6.Nginx的应用场景 二、负载均衡和动静分离 1.负载均衡 2.动静分离 三、Nginx搭载负载均衡并提供前后端分离后台接口数据 1.Nginx安装 2.tomcat负载均衡 …...

windows 运行 Mysql Command Line Client 自动关闭闪退原因分析

目录 原因分析一 原因分析二 原因分析三 第一次使用 MySQL Command Line Client 有可能输入密码后一按下回车键,程序窗口就自动关闭,出现闪退现象。本节主要分析产生闪退现象的原因以及如何处理这种情况。 原因分析一 首先可以查看程序默认执行文件…...

在CATIA工程制图中自动生成尺寸

然后微调即可...

蓝桥杯 (C++ 求和 等差数列 顺子日期 灌溉)

目录 1、求和 题目: 思路: 代码: 2、等差数列 题目: 思路: 代码: 3、顺子日期 题目: 思路: 代码: 4、灌溉 题目: 代码: 1、求和…...

Spring AOP基于XML方式笔记整理

XML AOP 加载流程 ClassPathXmlApplicationContext#refreshAbstractApplicationContext#obtainFreshBeanFactoryAbstractRefreshableApplicationContext#refreshBeanFactory创建DefaultListableBeanFactoryAbstractApplicationContext#loadBeanDefinitions(beanFactory)创建Xm…...

Docker HTTP(S) Proxy代理方式连接互联网

Docker HTTP(S) Proxy 是一种在 Docker 容器内部设置 HTTP(S) 代理的方法,以便于容器内的应用程序可以方便地通过代理访问互联网。设置 HTTP(S) 代理的方法主要有两种:使用 Dockerfile 配置和在使用 docker run 时添加参数。 以下是使用 Docker HTTP(S) …...

华纳云:centos系统中怎么查看cpu信息?

在CentOS系统中,我们可以使用一些命令来查看CPU的详细信息。下面介绍几个常用的命令: 1. lscpu lscpu命令可以显示CPU的架构、型号、核心数、线程数、频率等信息。 # lscpu 执行以上命令后,会输出类似以下内容: 2. cat /proc/…...

接口测试中缓存处理策略

在接口测试中,缓存处理策略是一个关键环节,直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性,避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明: 一、缓存处理的核…...

19c补丁后oracle属主变化,导致不能识别磁盘组

补丁后服务器重启,数据库再次无法启动 ORA01017: invalid username/password; logon denied Oracle 19c 在打上 19.23 或以上补丁版本后,存在与用户组权限相关的问题。具体表现为,Oracle 实例的运行用户(oracle)和集…...

【Linux】shell脚本忽略错误继续执行

在 shell 脚本中,可以使用 set -e 命令来设置脚本在遇到错误时退出执行。如果你希望脚本忽略错误并继续执行,可以在脚本开头添加 set e 命令来取消该设置。 举例1 #!/bin/bash# 取消 set -e 的设置 set e# 执行命令,并忽略错误 rm somefile…...

YSYX学习记录(八)

C语言&#xff0c;练习0&#xff1a; 先创建一个文件夹&#xff0c;我用的是物理机&#xff1a; 安装build-essential 练习1&#xff1a; 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件&#xff0c;随机修改或删除一部分&#xff0c;之后…...

JVM垃圾回收机制全解析

Java虚拟机&#xff08;JVM&#xff09;中的垃圾收集器&#xff08;Garbage Collector&#xff0c;简称GC&#xff09;是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象&#xff0c;从而释放内存空间&#xff0c;避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面

代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口&#xff08;适配服务端返回 Token&#xff09; export const login async (code, avatar) > {const res await http…...

微服务商城-商品微服务

数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...

三体问题详解

从物理学角度&#xff0c;三体问题之所以不稳定&#xff0c;是因为三个天体在万有引力作用下相互作用&#xff0c;形成一个非线性耦合系统。我们可以从牛顿经典力学出发&#xff0c;列出具体的运动方程&#xff0c;并说明为何这个系统本质上是混沌的&#xff0c;无法得到一般解…...

【HTTP三个基础问题】

面试官您好&#xff01;HTTP是超文本传输协议&#xff0c;是互联网上客户端和服务器之间传输超文本数据&#xff08;比如文字、图片、音频、视频等&#xff09;的核心协议&#xff0c;当前互联网应用最广泛的版本是HTTP1.1&#xff0c;它基于经典的C/S模型&#xff0c;也就是客…...

华硕a豆14 Air香氛版,美学与科技的馨香融合

在快节奏的现代生活中&#xff0c;我们渴望一个能激发创想、愉悦感官的工作与生活伙伴&#xff0c;它不仅是冰冷的科技工具&#xff0c;更能触动我们内心深处的细腻情感。正是在这样的期许下&#xff0c;华硕a豆14 Air香氛版翩然而至&#xff0c;它以一种前所未有的方式&#x…...