【HuggingFace Transformers】OpenAIGPTModel源码解析
OpenAIGPTModel源码解析
- 1. GPT 介绍
- 2. OpenAIGPTModel类 源码解析
说到ChatGPT,大家可能都使用过吧。2022年,ChatGPT的推出引发了广泛的关注和讨论。这款对话生成模型不仅具备了强大的语言理解和生成能力,还能进行非常自然的对话,给用户带来了全新的互动体验。然而,ChatGPT的成功背后离不开它的前身——GPT。
1. GPT 介绍
GPT(Generative Pre-trained Transformer)是由OpenAI开发的一种基于Transformer架构的大型语言模型。它由多个堆叠的自注意力解码器层(Transformer Blocks)组成,每一层包含多头自注意力机制和前馈神经网络,并配有残差连接和层归一化以稳定训练。GPT采用自回归方式生成文本,通过在大规模互联网数据上进行预训练,具备强大的自然语言理解和生成能力,能够完成对话生成、文本补全等多种任务。其结构如下:

2. OpenAIGPTModel类 源码解析
源码地址:transformers/src/transformers/models/openai/modeling_openai.py
# -*- coding: utf-8 -*-
# @time: 2024/9/3 20:39
from typing import Optional, Union, Tupleimport torchfrom torch import nn
from transformers import add_start_docstrings, OpenAIGPTPreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.openai.modeling_openai import OPENAI_GPT_START_DOCSTRING, Block, OPENAI_GPT_INPUTS_DOCSTRING, _CHECKPOINT_FOR_DOC, _CONFIG_FOR_DOC
from transformers.utils import add_start_docstrings_to_model_forward, add_code_sample_docstrings@add_start_docstrings("The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):def __init__(self, config):super().__init__(config)self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd) # 定义 token 嵌入层self.positions_embed = nn.Embedding(config.n_positions, config.n_embd) # 定义 position 嵌入层self.drop = nn.Dropout(config.embd_pdrop) # 定义 drop 层self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)]) # 定义多个 Block 层# 注册一个缓冲区用于存储position_ids,初始化为从 0 到 config.n_positions 的序列self.register_buffer("position_ids", torch.arange(config.n_positions), persistent=False)# Initialize weights and apply final processingself.post_init()def get_input_embeddings(self):return self.tokens_embeddef set_input_embeddings(self, new_embeddings):self.tokens_embed = new_embeddingsdef _prune_heads(self, heads_to_prune):"""Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}"""# 剪掉模型多头注意力机制中的一些头,heads_to_prune 是一个字典,键为layer_num,值为需要剪枝的 heads 列表。for layer, heads in heads_to_prune.items():self.h[layer].attn.prune_heads(heads)@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)@add_code_sample_docstrings(checkpoint=_CHECKPOINT_FOR_DOC,output_type=BaseModelOutput,config_class=_CONFIG_FOR_DOC,)def forward(self,input_ids: Optional[torch.LongTensor] = None,attention_mask: Optional[torch.FloatTensor] = None,token_type_ids: Optional[torch.LongTensor] = None,position_ids: Optional[torch.LongTensor] = None,head_mask: Optional[torch.FloatTensor] = None,inputs_embeds: Optional[torch.FloatTensor] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,) -> Union[Tuple[torch.Tensor], BaseModelOutput]:# 根据 config 配置设定 output_attentions, output_hidden_states, return_dict 的值output_attentions = output_attentions if output_attentions is not None else self.config.output_attentionsoutput_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)return_dict = return_dict if return_dict is not None else self.config.use_return_dict# 获取 input_ids 或者 inputs_embeds 以及 input_shapeif input_ids is not None and inputs_embeds is not None: # 当 input_ids 和 inputs_embeds 同时存在时,抛出错误raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")elif input_ids is not None: # 如果存在 input_ids,将其形状调整为 (batch_size, sequence_length)self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)input_shape = input_ids.size()input_ids = input_ids.view(-1, input_shape[-1])elif inputs_embeds is not None: # 如果存在 inputs_embeds,获取其形状input_shape = inputs_embeds.size()[:-1]else: # 如果 input_ids 和 inputs_embeds 都不存在,抛出错误raise ValueError("You have to specify either input_ids or inputs_embeds")# 如果没有传入 position_ids,则生成默认的 position_idsif position_ids is None:# Code is different from when we had a single embedding matrix from position and token embeddingsposition_ids = self.position_ids[None, : input_shape[-1]]# ------------------------------------- 1. 获取 attention_mask -----------------------------## Attention mask.if attention_mask is not None:# We create a 3D attention mask from a 2D tensor mask.# Sizes are [batch_size, 1, 1, to_seq_length]# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]# this attention mask is more simple than the triangular masking of causal attention# used in OpenAI GPT, we just need to prepare the broadcast dimension here.attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # 将 2D 掩码扩展为 3D 掩码,适用于批量输入# Since attention_mask is 1.0 for positions we want to attend and 0.0 for# masked positions, this operation will create a tensor which is 0.0 for# positions we want to attend and the dtype's smallest value for masked positions.# Since we are adding it to the raw scores before the softmax, this is# effectively the same as removing these entirely.# 将注意力掩码转换为与模型参数相同的数据类型,并进行数值变换,torch.finfo(self.dtype).min 返回数据类型的最小值。attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibilityattention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min# ----------------------------------------------------------------------------------------## ------------------------------------- 2. 获取 head_mask ---------------------------------## Prepare head mask if neededhead_mask = self.get_head_mask(head_mask, self.config.n_layer)# ---------------------------------------------------------- -----------------------------## ------------------------------------- 3. 获取 hidden_states -----------------------------## 如果 inputs_embeds 为 None,则使用 tokens_embed 对 input_ids 计算if inputs_embeds is None:inputs_embeds = self.tokens_embed(input_ids)# 计算 position_embedsposition_embeds = self.positions_embed(position_ids)# 如果存在 token_type_ids,使用 tokens_embed 计算;否则 token_type_embeds 为 0if token_type_ids is not None:token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))token_type_embeds = self.tokens_embed(token_type_ids)else:token_type_embeds = 0# 计算 hidden_states,即inputs_embeds、position_embeds 和 token_type_embeds 之和,并使用 dropouthidden_states = inputs_embeds + position_embeds + token_type_embedshidden_states = self.drop(hidden_states)# -------------------------------------------------------------------------------------## 获取输出形状,以及初始化输出结果 all_attentions 和 all_hidden_statesoutput_shape = input_shape + (hidden_states.size(-1),)all_attentions = () if output_attentions else Noneall_hidden_states = () if output_hidden_states else None# -----------------------------------4. Block逐层计算处理(核心部分)--------------------#for i, block in enumerate(self.h):# 如果需要输出 hidden states,将当前 hidden_states 添加到 all_hidden_statesif output_hidden_states:all_hidden_states = all_hidden_states + (hidden_states,)# 通过当前 Block 处理 hidden_states,得到新的 hidden_states 和 attentionsoutputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions)hidden_states = outputs[0]# 如果需要输出 attentions,将当前 attentions 添加到 all_attentionsif output_attentions:all_attentions = all_attentions + (outputs[1],)# ---------------------------------------------------------------------------------## 将 hidden_states 的形状调整为输出形状hidden_states = hidden_states.view(*output_shape)# 如果需要输出 hidden states,将最后的 hidden_states 添加到 all_hidden_statesif output_hidden_states:all_hidden_states = all_hidden_states + (hidden_states,)# -----------------------------------5. 根据配置的输出方式输出结果-------------------------------#if not return_dict:return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)return BaseModelOutput(last_hidden_state=hidden_states,hidden_states=all_hidden_states,attentions=all_attentions,)
相关文章:
【HuggingFace Transformers】OpenAIGPTModel源码解析
OpenAIGPTModel源码解析 1. GPT 介绍2. OpenAIGPTModel类 源码解析 说到ChatGPT,大家可能都使用过吧。2022年,ChatGPT的推出引发了广泛的关注和讨论。这款对话生成模型不仅具备了强大的语言理解和生成能力,还能进行非常自然的对话,…...
macOS安装Java和Maven
安装Java Java Downloads | Oracle 官网下载默认说最新的Java22版本,注意这里我们要下载的是Java8,对应的JDK1.8 需要登陆Oracle,没有账号的可以百度下。账号:908344069qq.com 密码:Java_2024 Java8 jdk1.8配置环境变量 open -e ~/.bash_p…...
SpringBoot教程(安装篇) | Elasticsearch的安装
SpringBoot教程(安装篇) | Elasticsearch的安装 一、确定Elasticsearch版本二、下载elasticsearch(windows版本)官网下载如何解压配置 允许 别人跨域 访问自己启动运行 三、Es可视化工具安装(elasticsearch-head&#…...
前端登录鉴权——以若依Ruoyi前后端分离项目为例解读
权限模型 Ruoyi框架学习——权限管理_若依框架权限-CSDN博客 用户-角色-菜单(User-Role-Menu)模型是一种常用于权限管理的设计模式,用于实现系统中的用户权限控制。该模型主要包含以下几个要素: 用户(User)…...
【Tools】大模型中的自注意力机制
摇来摇去摇碎点点的金黄 伸手牵来一片梦的霞光 南方的小巷推开多情的门窗 年轻和我们歌唱 摇来摇去摇着温柔的阳光 轻轻托起一件梦的衣裳 古老的都市每天都改变模样 🎵 方芳《摇太阳》 自注意力机制(Self-Attention)是一…...
PhotoZoom Classic 9软件新功能特性及安装激活图文教程
PhotoZoom Classic 9这款软件能够对数码图片进行放大,而且放大后的图片没有任何的品质的损坏,没有锯齿,不会失真,如果您有兴趣的话可以试试哦! PhotoZoom Classic 9软件新功能特性 通过屡获殊荣的 S-Spline XL 插值…...
【数据结构】直接插入排序
目录 一、基本思想 二、动图演示 三、思路分析 四、代码实现 五、易错提醒 六、时间复杂度分析 一、基本思想 直接插入排序(Straight Insertion Sort)是一种简单直观的排序算法,其基本思想是: 把待排序的一个记录按其关键码…...
JavaScript 实现虚拟滚动技术
虚拟滚动 虚拟滚动(有时称为 虚拟列表、虚拟滚动条)是 JavaScript 中的一种技术,旨在优化大数据量的列表渲染,尤其是当有成千上万的数据项时,直接渲染整个列表会导致性能问题。虚拟列表通过只渲染用户视口中可见的那一…...
【重学 MySQL】十八、逻辑运算符的使用
【重学 MySQL】十八、逻辑运算符的使用 AND运算符OR运算符NOT运算符异或运算符使用 XOR 关键字使用 BIT_XOR() 函数注意事项 注意事项 在MySQL中,逻辑运算符是构建复杂查询语句的重要工具,它们用于处理布尔类型的数据,进行逻辑判断和组合条件…...
关于 QImage原始数据格式与cv::Mat原始数据进行手码数据转换 的解决方法
若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/141996117 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…...
前端WebSocket客户端实现
// 创建WebSocket连接 var socket new WebSocket(ws://your-spring-boot-server-url/websocket-endpoint);// 连接打开时触发 socket.addEventListener(open, function (event) {socket.send(JSON.stringify({type: JOIN, room: general})); });// 监听从服务器来的消息 socke…...
读取realsense d455双目及imu
问题定义 实时读取realsense数据喂给slam系统 代码 /** rs_d455设备 */#include <librealsense2/rs.hpp> #include <iostream>#include "rs_common_device.h"// opencv #include <opencv2/opencv.hpp>class RsD455Device: public rsCmmonDevice…...
浮点的运算
浮点数表示: N 尾数 * 基数指数 1.25 X 106 尾数一般用补码,指数一般用移码 在IEEE745中尾数可以是原码。 尾数可以表示数值的有效精度,位数越多精度越高 阶码的位数决定数的表示范围,位数越多,范围越大 对阶时&…...
对随机游走问题的分析特定行为模式的建模
从一段随机游走的数据中寻找特定的行为模式,这种问题涉及 序列模式识别 或 序列分析。处理这种问题的算法选择取决于你要找的模式的具体性质和复杂性。以下是几种可能的算法: 隐马尔可夫模型(HMM) 隐马尔可夫模型特别适合处理随…...
JVM面试(七)G1垃圾收集器剖析
概述 上一章我们说了,G1收集器,它属于里程碑式的发展,开创了面向局部收集垃圾的概念。专门针对多核处理器以及大内存的机器。在JDK9中,更是呗指定为官方的GC收集器。满足高吞吐的通知满足GC的STW停顿时间尽可能的短。 虽然现在我…...
php转职golang第一期
入局golang 基础语法:学习 Go 语言的基本语法、数据类型、流程控制等。 数据结构与算法:掌握常用的数据结构和算法。 Web 开发基础:了解 HTTP 协议、Web 开发的基本概念。 Gin 框架或其他 Web 框架:深入学习使用一种 Go 的 Web…...
java后端服务监控与告警:Prometheus与Grafana集成
Java后端服务监控与告警:Prometheus与Grafana集成 大家好,我是微赚淘客返利系统3.0的小编,是个冬天不穿秋裤,天冷也要风度的程序猿! 在现代的微服务架构中,监控和告警是确保服务稳定性的关键组成部分。Pr…...
【系统架构设计师】工厂方法设计模式
工厂方法(Factory Method)模式是一种创建型设计模式,它定义了一个用于创建对象的接口,但让子类决定要实例化的类是哪一个。工厂方法让类的实例化延迟到子类中进行。 工厂方法模式的主要角色 产品(Product):定义工厂的创建对象的接口。具体产品(Concrete Product):实…...
怎样解决OpenEuler下载sdl2失败
OpenEuler 下载 sdl2失败 解决办法(使用wget中git上下载) wget https://github.com/libsdl-org/SDL/releases/download/release-2.30.6/SDL2-2.30.6.tar.gz使用yum下载,下载的最后说找不到这样的库(no match)使用 apt-get,说找不到apt-get使用curl冲gi…...
基于Python的自然语言处理系列(2):Word2Vec(负采样)
在本系列的第二篇文章中,我们将继续探讨Word2Vec模型,这次重点介绍负采样(Negative Sampling)技术。负采样是一种优化Skip-gram模型训练效率的技术,它能在大规模语料库中显著减少计算复杂度。接下来,我们将…...
Real-ESRGAN-GUI:如何用AI双引擎将模糊图片一键变高清
Real-ESRGAN-GUI:如何用AI双引擎将模糊图片一键变高清 【免费下载链接】Real-ESRGAN-GUI Lovely Real-ESRGAN / Real-CUGAN GUI Wrapper 项目地址: https://gitcode.com/gh_mirrors/re/Real-ESRGAN-GUI 还在为模糊的老照片、低分辨率的动漫图片而烦恼吗&…...
Flink源码阅读:双流操作
Window Join我们先回顾一下 window join 的使用方法。DataStream<Tuple2<String, Double>> result source1.join(source2).where(record -> record.f0).equalTo(record -> record.f0).window(TumblingEventTimeWindows.of(Time.seconds(2L))).apply(new Joi…...
嵌入式开发必备:三大代码对比工具深度评测
1. 代码对比工具概述作为一名嵌入式开发工程师,我每天都要处理大量的代码修改和版本对比工作。在多年的开发实践中,我发现选择合适的代码对比工具能极大提升工作效率。虽然Beyond Compare是业内公认的标杆产品,但实际工作中我们还有更多选择&…...
用Simulink+Carsim复现论文:四轮转向后轮控制5种算法对比(附模型下载)
用SimulinkCarsim复现论文:四轮转向后轮控制5种算法对比(附模型下载) 在车辆动力学与控制领域,四轮转向技术正逐渐从豪华车型向主流市场渗透。不同于传统的前轮转向系统,四轮转向通过后轮主动参与转向,显著…...
3.多表关联在电商数据分析中的核心价值
多表关联在电商数据分析中的核心价值 第1章 多表关联、子查询与行列转换在电商数据分析中的核心价值 1.1 为什么单表查询不够用 我刚开始做数据分析的时候,以为SQL就是在一张表上做筛选和汇总。直到有一天,运营问我:“这批高价值用户…...
告别照相馆!AI头像生成器教你免费制作高质量职业头像
告别照相馆!AI头像生成器教你免费制作高质量职业头像 1. 为什么选择AI生成职业头像? 在当今数字化求职环境中,一张专业的头像照片已经成为简历不可或缺的部分。传统照相馆拍摄存在三个主要痛点: 成本高昂:专业摄影工…...
Wan2.2-I2V-A14B文生视频模型落地实践:单卡4090D高效推理部署案例
Wan2.2-I2V-A14B文生视频模型落地实践:单卡4090D高效推理部署案例 1. 项目背景与价值 视频内容创作正成为数字时代的重要需求,但传统视频制作流程复杂、成本高昂。Wan2.2-I2V-A14B作为新一代文生视频模型,能够直接将文本描述转化为高质量视…...
电商评论分析利器:GTE文本向量实战情感分析与产品问题挖掘
电商评论分析利器:GTE文本向量实战情感分析与产品问题挖掘 1. 电商评论分析的痛点与解决方案 电商平台每天产生海量用户评论,这些评论蕴含着消费者真实的产品体验和市场反馈。传统的人工分析方法面临三大挑战: 处理效率低:人工…...
MaxKB:企业级AI知识库部署实战指南
MaxKB:企业级AI知识库部署实战指南 【免费下载链接】MaxKB 🔥 MaxKB is an open-source platform for building enterprise-grade agents. 强大易用的开源企业级智能体平台。 项目地址: https://gitcode.com/GitHub_Trending/ma/MaxKB 面对企业AI…...
告别重复劳动:用快马AI智能生成OpenCode风格的高效工具函数
最近在开发一个需要大量表单验证的项目时,我发现每次都要重复写类似的验证逻辑,既浪费时间又容易出错。于是我开始寻找更高效的解决方案,最终在InsCode(快马)平台上找到了理想的工具。 需求分析 表单验证是每个Web项目都绕不开的基础功能。常…...
