当前位置: 首页 > news >正文

简单的二元语言模型bigram实现

  • 内容总结归纳自视频:【珍藏】从头开始用代码构建GPT - 大神Andrej Karpathy 的“神经网络从Zero到Hero 系列”之七_哔哩哔哩_bilibili 
  • 项目:https://github.com/karpathy/ng-video-lecture

Bigram模型是基于当前Token预测下一个Token的模型。例如,如果输入序列是`[A, B, C]`,那么模型会根据`A`预测`B`,根据`B`预测`C`,依此类推,实现自回归生成。在生成新Token时,通常只需要最后一个Token的信息,因为每个预测仅依赖于当前Token。

1. 训练batch数据形式

训练数据是:

 训练目标是:

2. 定义词嵌入层

nn.Embedding 层输出的是可学习的浮点数,将token索引 (B,T) 直接映射为logits,即输入(4,8),输出 (4,8,65),其中输入每个数字,被映射成logit向量(这些值通过 F.cross_entropy 内部自动进行 softmax 转换为概率分布),比如上面输入tokens有个24被映射成如下。

logits = [1.0, 0.5, -2.0, ..., 3.2]  # 共65个浮点数

softmax后得到。

probs = [0.15, 0.12, 0.01, ..., 0.20]  # 和为1的概率分布

这样输出的是每个位置的概率分布。

交叉熵函数会自动计算每个位置的概率分布与真实标签之间的损失,并取平均。

简单的大语言模型,基于Bigram的结构,即每个token仅根据前一个token来预测下一个token。具体实现如下。

from torch.nn import functional as F  # 导入PyTorch函数模块
torch.manual_seed(1337)  # 固定随机种子保证结果可复现class BigramLanguageModel(nn.Module):  # 定义Bigram语言模型类def __init__(self, vocab_size):super().__init__()  # 继承父类初始化方法# 定义词嵌入层:将token索引直接映射为logits# 输入输出维度均为vocab_size(词汇表大小)self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)def forward(self, idx, targets):# 前向传播函数# idx: inputs, 输入序列 (B, T),B=批次数,T=序列长度# targets: 目标序列 (B, T)# (4,8) -> (4,8,65)logits = self.token_embedding_table(idx)  # (B, T, C)# 通过嵌入层获得每个位置的概率分布,C=词汇表大小# (4*8,C), (4*8,) -> (1,)B, T, C = logits.shape  # 解包维度:批次数、序列长度、词表大小logits = logits.view(B*T, C)  # 展平为二维张量 (B*T, C)targets = targets.view(B*T)    # 目标展平为一维张量 (B*T)loss = F.cross_entropy(logits, targets)  # 计算交叉熵损失return logits, loss  # 返回logits(未归一化概率)和损失值# 假设 vocab_size=65(例如52字母+标点)
vocab_size = 65
m = BigramLanguageModel(vocab_size)  # 实例化模型# 假设输入数据(代码中未定义):
# xb: 输入批次 (B=4, T=8),例如 tensor([[1,2,3,...], ...])
# yb: 目标批次 (B=4, T=8)
logits, loss = m(xb, yb)  # 执行前向传播print(logits.shape)  # 输出logits形状:torch.Size([32, 65])
# 解释:32 = B*T = 4*8,65=词表大小(每个位置65种可能)print(loss)  # 输出损失值:tensor(4.8786, grad_fn=<NllLossBackward>)
# 解释:初始随机参数下,损失值约为-ln(1/65)=4.17,实际值因参数初始化略有波动

3. 代码逻辑分步解释

# 假设输入和目标的形状均为 (B=4, T=8)
# 输入示例(第一个样本):
inputs[0] = [24, 43, 58, 5, 57, 1, 46, 43]
targets[0] = [43, 58, 5, 57, 1, 46, 43, 39]

3.1 Softmax后的概率分布意义 

当模型处理输入序列时,每个位置会输出一个长度为vocab_size的logits向量。即

输入: (4,8);

输出:(4,8,65). 65维度向量是每个输入token的下一个token的概率分布。

例如,当输入序列为 [24, 43, 58, 5, 57, 1, 46, 43] 时:
- 在第1个位置(token=24),模型预测下一个token(对应target=43)的概率分布p[0].shape=(65,),下一个输出是43的概率为p[0][target[0]]=p[0][43];
- 在第2个位置(token=43),模型预测下一个token(对应target=58)的概率分布p[1].shape=(65),下一个输出是58的概率为p[1][target[1]]=p[1][58];
- 以此类推,每个位置的logits经过softmax后得到一个概率分布,即每个输入位置,都会预测下一个token概率分布
具体来说:
   logits.shape = (4, 8, 65) → softmax后形状不变->p.shape=(4,8,65),但每行的65个值变为概率(和为1)
这些概率表示模型认为「当前token的下一个token」是词汇表中各token的可能性。

3.2 交叉熵计算步骤

假设logits初始形状为 (4, 8, 65)
B, T, C = logits.shape  # B=4, T=8, C=65

# 展平logits和targets:
logits_flat = logits.view(B*T, C)  # 形状 (32, 65)
targets_flat = targets.view(B*T)    # 形状 (32,)

# 交叉熵计算(PyTorch内部过程):
# 对logits_flat的每一行(共32行)做softmax,得到概率分布probs (32, 65)
# 对每个样本i,取probs[i][targets_flat[i]],即真实标签对应的预测概率(此概率是下一个token是targets_flat[i]的概率
# 计算负对数损失:loss = -mean(log(probs[i][targets_flat[i]]))(pytorch实现是将targets_flat所谓索引
loss = F.cross_entropy(logits_flat, targets_flat)  # 输出标量值

3.3 示例计算

# 以第一个样本的第一个位置为例:
# 输入token=24,目标token=43
# 模型输出的logits[0,0]是一个65维向量(这里logits.shape=[4,8,65]),例如:
logits_example = logits[0,0]  # 形状 (65,)
probs_example = F.softmax(logits_example, dim=-1)  # 形状 (65,)
# 假设probs_example[43] = 0.15(模型预测下一个token=43的概率为15%)
# 则此位置的损失为 -log(0.15) ≈ 1.897 (注意-log(p)是一个x范围在[0,1]之间单调递减函数)

# 最终损失是所有32个位置类似计算的均值。
# 初始损失约为4.87(接近均匀分布的理论值 -ln(1/65)≈4.17)

4. 测试生成文本

# super simple bigram model
class BigramLanguageModel(nn.Module):def __init__(self, vocab_size):super().__init__()# each token directly reads off the logits for the next token from a lookup tableself.token_embedding_table = nn.Embedding(vocab_size, vocab_size)def forward(self, idx, targets=None):# idx and targets are both (B,T) tensor of integerslogits = self.token_embedding_table(idx) # (B,T,C)if targets is None:loss = Noneelse:B, T, C = logits.shapelogits = logits.view(B*T, C)targets = targets.view(B*T)loss = F.cross_entropy(logits, targets)return logits, lossdef generate(self, idx, max_new_tokens):# idx is (B, T) array of indices in the current contextfor _ in range(max_new_tokens):# get the predictionslogits, loss = self(idx) # 没有输入target时,返回的logits未被展平。  # focus only on the last time steplogits = logits[:, -1, :] # (B,T,C) -> (B, C)# apply softmax to get probabilitiesprobs = F.softmax(logits, dim=-1) # (B, C)# sample from the distributionidx_next = torch.multinomial(probs, num_samples=1) # (B, 1)# append sampled index to the running sequenceidx = torch.cat((idx, idx_next), dim=1) # (B, T+1)return idx

以下是Bigram模型生成过程的逐步详解,以输入序列[24, 43, 58, 5, 57, 1, 46, 43]为例,说明模型如何从初始输入[24]开始逐步预测下一个词: 

1. 初始输入:[24]

  • 输入形状idx = [[24]]B=1批次,T=1序列长度)。

  • 前向传播

    • 通过嵌入层,模型输出logits形状为(1, 1, 65),表示对当前词24的下一个词的预测分数。

    • 假设logits[0, 0, 43] = 5.0(词43的logit较高),其他位置logits较低(如logits[0, 0, :] = [..., 5.0, ...])。

  • 概率分布

    • 对logits应用softmax,得到概率分布probs。例如:

      probs = [0.01, ..., 0.8(对应43), 0.01, ...]  # 总和为1
  • 采样

    • 根据probs,使用torch.multinomial采样,选中词43的概率最大。

  • 更新输入

    • 43拼接到序列末尾,新输入为idx = [[24, 43]](形状(1, 2))。


2. 输入:[24, 43]

  • 前向传播

    • 模型处理整个序列,输出logits形状为(1, 2, 65),对应两个位置的预测:

      • 第1个位置(词24)预测下一个词(已生成43)。

      • 第2个位置(词43)预测下一个词。

    • 提取最后一个位置的logits:logits[:, -1, :](形状(1, 65))。

    • 假设logits[0, -1, 58] = 6.0(词58的logit较高)。

  • 概率分布

    • probs = [0.01, ..., 0.85(对应58), 0.01, ...]

  • 采样

    • 选中词58

  • 更新输入

    • 新输入为idx = [[24, 43, 58]](形状(1, 3))。


3. 输入:[24, 43, 58]

  • 前向传播

    • logits形状为(1, 3, 65)

    • 提取最后一个位置(词58)的logits,假设logits[0, -1, 5] = 4.5

  • 概率分布

    • probs = [0.01, ..., 0.7(对应5), ...]

  • 采样

    • 选中词5

  • 更新输入

    • 新输入为idx = [[24, 43, 58, 5]](形状(1, 4))。


4. 重复生成直到序列完成

  • 后续步骤

    • 输入[24, 43, 58, 5] → 预测词57

    • 输入[24, 43, 58, 5, 57] → 预测词1

    • 输入[24, 43, 58, 5, 57, 1] → 预测词46

    • 输入[24, 43, 58, 5, 57, 1, 46] → 预测词43

  • 最终序列

    • idx = [[24, 43, 58, 5, 57, 1, 46, 43]]

注意:上面输入序列是越来越长的,为何说预测下一个词只跟上一个词有关?如果只跟一个词有关,为何不每次只输入一个词,然后预测下一个词?

虽然理论上可以仅传递最后一个词,但实际实现中传递完整序列的原因(视频作者说的,固定generate函数形式,我这里理解的是代码简洁):

  • 代码简洁性:无需在每次生成时截取最后一个词,直接复用统一的前向传播逻辑;

实验验证

若修改代码,每次仅传递最后一个词:

def generate(self, idx, max_new_tokens):for _ in range(max_new_tokens):last_token = idx[:, -1:]          # 仅取最后一个词 (B, 1)logits, _ = self(last_token)       # 输出形状 (B, 1, C)probs = F.softmax(logits[:, -1, :], dim=-1)idx_next = torch.multinomial(probs, num_samples=1)idx = torch.cat((idx, idx_next), dim=1)return idx

相关文章:

简单的二元语言模型bigram实现

内容总结归纳自视频&#xff1a;【珍藏】从头开始用代码构建GPT - 大神Andrej Karpathy 的“神经网络从Zero到Hero 系列”之七_哔哩哔哩_bilibili 项目&#xff1a;https://github.com/karpathy/ng-video-lecture Bigram模型是基于当前Token预测下一个Token的模型。例如&#x…...

【清华大学】实用DeepSeek赋能家庭教育 56页PDF文档完整版

清华大学-56页&#xff1a;实用DeepSeek赋能家庭教育.pdf https://pan.baidu.com/s/1BUweVDeG2M8-t0QaIs3LHQ?pwd1234 提取码: 1234 或 https://pan.quark.cn/s/8a9473493bb0 《实用DeepSeek赋能家庭教育》基于清华大学研究成果&#xff0c;系统阐述了DeepSeek人工智能技…...

黑洞如何阻止光子逃逸

虽然涉及广义相对论&#xff0c;但广义相对论说的是大质量物体对周围空间的影响&#xff0c;而不是说周围空间和空间中的光子之间的关系。也就是说&#xff0c;若讨论光子逃逸问题&#xff0c;则不必限定于大质量的前提&#xff0c;也就是说&#xff0c;若质量周围被扭曲的空间…...

1.4 单元测试与热部署

本次实战实现Spring Boot的单元测试与热部署功能。单元测试方面&#xff0c;通过JUnit和Mockito等工具&#xff0c;结合SpringBootTest注解&#xff0c;可以模拟真实环境对应用组件进行独立测试&#xff0c;验证逻辑正确性&#xff0c;提升代码质量。具体演示了HelloWorld01和H…...

window系统中的start命令详解

start 是 Windows 系统中用于启动新进程或打开新窗口来运行指定程序或命令的命令。以下是对 start 命令参数的详细解释&#xff1a; 基本语法 start ["title"] [/Dpath] [/I] [/MIN] [/MAX] [/SEPARATE | /SHARED] [/LOW | /NORMAL | /HIGH | /REALTIME | /ABOVENO…...

AI编程工具节选

1、文心快码 百度基于文心大模型推出的一款智能编码助手&#xff0c; 官网地址&#xff1a;文心快码(Baidu Comate)更懂你的智能代码助手 2、通义灵码 阿里云出品的一款基于通义大模型的智能编码辅助工具&#xff0c; 官网地址&#xff1a;通义灵码_你的智能编码助手-阿里云 …...

正则表达式,idea,插件anyrule

​​​​package lx;import java.util.regex.Pattern;public class lxx {public static void main(String[] args) {//正则表达式//写一个电话号码的正则表达式String regex "1[3-9]\\d{9}";//第一个数字是1&#xff0c;第二个数字是3-9&#xff0c;后面跟着9个数字…...

原生iOS集成react-native (react-native 0.65+)

由于官方文档比较老&#xff0c;很多配置都不能用&#xff0c;集成的时候遇到很多坑&#xff0c;简单的整理一下 时间节点:2021年9月1日 本文主要提供一些配置信息以及错误信息解决方案&#xff0c;具体步骤可以参照官方文档 原版文档&#xff1a;https://reactnative.dev/docs…...

java错题总结

本篇文章用来记录学习javaSE以来的错题 解答&#xff1a;重载要求俩个方法的名字相同&#xff0c;但参数的类型或者个数不同&#xff0c;但是不要求返回类型相同&#xff0c;所以A正确。 重写还需要要求返回类型相同&#xff08;呈现父子类关系也可以&#xff0c;但是属于特例&…...

【商城实战(10)】解锁商品信息录入与展示的技术密码

【商城实战】专栏重磅来袭&#xff01;这是一份专为开发者与电商从业者打造的超详细指南。从项目基础搭建&#xff0c;运用 uniapp、Element Plus、SpringBoot 搭建商城框架&#xff0c;到用户、商品、订单等核心模块开发&#xff0c;再到性能优化、安全加固、多端适配&#xf…...

2025年主流原型工具测评:墨刀、Axure、Figma、Sketch

2025年主流原型工具测评&#xff1a;墨刀、Axure、Figma、Sketch 要说2025年国内产品经理使用的主流原型设计工具&#xff0c;当然是墨刀、Axure、Figma和Sketch了&#xff0c;但是很多刚入行的产品经理不了解自己适合哪些工具&#xff0c;本文将从核心优势、局限短板、协作能…...

MDM 如何彻底改变医疗设备的远程管理

在现代医疗行业迅速发展的格局中&#xff0c;医院和诊所越来越依赖诸如医疗平板和移动工作站等移动设备。这些设备在提高工作效率和提供卓越的患者护理方面发挥着关键作用。然而&#xff0c;随着它们的广泛使用&#xff0c;也带来了一系列挑战&#xff0c;例如在不同地点确保数…...

OpenCV计算摄影学(18)平滑图像中的纹理区域同时保留边缘信息函数textureFlattening()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 cv::textureFlattening 是 OpenCV 中用于图像处理的一个函数&#xff0c;旨在平滑图像中的纹理区域&#xff0c;同时保留边缘信息。该技术特别适…...

用DeepSeek学Android开发:Android初学者遇到的常见问题有哪些?如何解决?

答案来自 DeepSeek Q: Android初学者遇到的常见问题有哪些&#xff1f;如何解决&#xff1f; A: Android初学者在学习过程中常会遇到以下问题及对应的解决方法&#xff0c;按类别整理如下&#xff1a; 一、开发环境问题 Android Studio安装或配置问题 问题&#xff1a;安装失…...

springboot 集成 MongoDB 基础篇

demo架构&#xff1a; Book Controller&#xff1a; package com.zy.controller;import com.zy.entity.Book; import com.zy.service.MongoDbService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.Get…...

大白话html语义化标签优势与应用场景

大白话html语义化标签优势与应用场景 大白话解释 语义化标签就是那些名字能让人一看就大概知道它是用来做什么的标签。以前我们经常用<div>来做各种布局&#xff0c;但是<div>本身没有什么实际的含义&#xff0c;就像一个没有名字的盒子。而语义化标签就像是有名…...

恶劣天候三维目标检测论文列表整理

恶劣天候三维目标检测论文列表 图摘自Kradar &#x1f3e0; 介绍 Hi&#xff0c;这是有关恶劣天气下三维目标检测的论文列表。主要是来源于近3年研究过程中认为有意义的文章。希望能为新入门的研究者提供一些帮助。 可能比较简陋&#xff0c;存在一定的遗漏&#xff0c;欢迎…...

conda的环境起的jupyter用不了已经安装的包如何解决

当你在使用Conda环境中的Jupyter Notebook时遇到无法读取某些库或模块的问题&#xff0c;通常是由以下几个原因引起的&#xff1a; 环境未激活&#xff1a;确保你已经在正确的Conda环境中激活了Jupyter Notebook。 库未安装&#xff1a;可能你需要的库没有在当前的Conda环境中…...

蓝桥杯题型

蓝桥杯题型分类 二分 123 传送门 1. 小区间的构成 假设数列的构成是如下形式&#xff1a; 第 1 个区间包含 1 个元素&#xff08;1&#xff09;。第 2 个区间包含 2 个元素&#xff08;1 2&#xff09;。第 3 个区间包含 3 个元素&#xff08;1 2 3&#xff09;。第 4 个区…...

STM32-I2C通信协议

一&#xff1a;I2C通信协议 就是在串口通信上满足四个要求 要求1&#xff1a;删掉一根通信线&#xff0c;防止资源浪费&#xff0c;只能在同一根线上进行发送和接收要求2&#xff1a;需要一个应答机制&#xff0c;没发送一个字节都有一次应答要求3&#xff1a;一根线上能同时…...

taosd 写入与查询场景下压缩解压及加密解密的 CPU 占用分析

在当今大数据时代&#xff0c;时序数据库的应用越来越广泛&#xff0c;尤其是在物联网、工业监控、金融分析等领域。TDengine 作为一款高性能的时序数据库&#xff0c;凭借独特的存储架构和高效的压缩算法&#xff0c;在存储和查询效率上表现出色。然而&#xff0c;随着数据规模…...

uniapp微信小程序vue3自定义tabbar

在App.vue隐藏原生tabbar&#xff0c;也可以在pages.json中配置 二选一就好了 创建 CustomTabBar 公共组件 <template><view class"custom-tab-bar" :style"{paddingBottom: safeAreaHeight px}"><view class"tab-bar-item" :…...

BUUCTF——[GYCTF2020]FlaskApp1 SSTI模板注入/PIN学习

目录 一、网页功能探索 二、SSTI注入 三、方法一 四、方法二 使用PIN码 &#xff08;1&#xff09;服务器运行flask登录所需的用户名 &#xff08;2&#xff09;modename &#xff08;3&#xff09;flask库下app.py的绝对路径 &#xff08;4&#xff09;当前网络的mac地…...

如何用Kimi生成PPT?秒出PPT更高效!

做PPT是不是总是让你头疼&#xff1f;&#x1f629; 快速制作出专业的PPT&#xff0c;今天我们要推荐两款超级好用的AI工具——Kimi 和 秒出PPT&#xff01;我们来看看哪一款更适合你吧&#xff01;&#x1f680; &#x1f947; Kimi&#xff1a;让PPT制作更轻松 Kimi的生成效…...

数据结构(回顾)

数据结构&#xff08;回顾&#xff09; 回顾 不同点顺序表链表存储空间上物理上一定连续逻辑上连续&#xff0c;物理上不一定连续随机访问支持&#xff0c;时间复杂度O(1)不支持&#xff0c;时间复杂度O(N)任意位置插入或者删除元素可能需要挪动元素&#xff0c;效率低&#…...

全国产!瑞芯微3562Mini(2GHz四核A53 NPU)工业开发板规格书

评估板简介 创龙科技 TL3562-MiniEVM 是一款基于瑞芯微 RK3562J/RK3562 处理器设计的四核 AR M Cortex-A53 单核 ARM Cortex-M0 国产工业评估板&#xff0c;主频高达 2.0GHz。评估板由核心板和评估底板组成&#xff0c;核心板 CPU、ROM、RAM、电源、晶振等所有元器件均采用国…...

鸿蒙HarmonyOS评论功能小demo

评论页面小demo 效果展示 1.拆解组件&#xff0c;分层搭建 我们将整个评论页面拆解为三个组件&#xff0c;分别是头部导航&#xff0c;评论项&#xff0c;回复三个部分&#xff0c;然后统一在index界面导入 2.头部导航界面搭建 Preview Component struct HmNavBar {// 属性&a…...

异常(6)

今天我们继续来讲异常的内容,关于异常的捕获和声明,也是在处理异常的的重要方式,话不多说,来看. 异常的捕获 异常的捕获,也就是异常,的具体处理方式,主要有两种,主要有两种&#xff1a;异常声明throws以及try-catch捕获处理. 3.1异常声明throws. 处在方法声明时参数列表之后…...

精选一百道备赛蓝桥杯——2.K倍区间

解题思路 任何两个前缀区间的和对k取模的值相等&#xff0c;则由大的前缀区间减掉小的前缀区间所形成的区间的必定是K倍区间。因此我们可以对具有区间和%k值相等任何两个区间进行组合&#xff0c;再将这些值加起来就得到结果&#xff01;证明&#xff1a; 假设一个数列为a1,a2…...

编译Telegram Desktop

目录 一、前言 二、环境准备 2.1 2.2 2.3 2.4 2.5 2.6 2.7 2.8 三、编译 四、总结和学习 一、前言 Telegram 是一款全球广泛使用的即时通讯软件&#xff0c;以其强大的隐私保护、跨平台同步和丰富的功能而闻名。它支持一对一聊天、群组&#xff08;最多20万成员&am…...