简单的二元语言模型bigram实现
- 内容总结归纳自视频:【珍藏】从头开始用代码构建GPT - 大神Andrej Karpathy 的“神经网络从Zero到Hero 系列”之七_哔哩哔哩_bilibili
 - 项目:https://github.com/karpathy/ng-video-lecture
 
Bigram模型是基于当前Token预测下一个Token的模型。例如,如果输入序列是`[A, B, C]`,那么模型会根据`A`预测`B`,根据`B`预测`C`,依此类推,实现自回归生成。在生成新Token时,通常只需要最后一个Token的信息,因为每个预测仅依赖于当前Token。
1. 训练batch数据形式
训练数据是:

训练目标是:

2. 定义词嵌入层
nn.Embedding 层输出的是可学习的浮点数,将token索引 (B,T) 直接映射为logits,即输入(4,8),输出 (4,8,65),其中输入每个数字,被映射成logit向量(这些值通过 F.cross_entropy 内部自动进行 softmax 转换为概率分布),比如上面输入tokens有个24被映射成如下。
logits = [1.0, 0.5, -2.0, ..., 3.2] # 共65个浮点数
softmax后得到。
probs = [0.15, 0.12, 0.01, ..., 0.20] # 和为1的概率分布
这样输出的是每个位置的概率分布。
交叉熵函数会自动计算每个位置的概率分布与真实标签之间的损失,并取平均。
简单的大语言模型,基于Bigram的结构,即每个token仅根据前一个token来预测下一个token。具体实现如下。
from torch.nn import functional as F  # 导入PyTorch函数模块
torch.manual_seed(1337)  # 固定随机种子保证结果可复现class BigramLanguageModel(nn.Module):  # 定义Bigram语言模型类def __init__(self, vocab_size):super().__init__()  # 继承父类初始化方法# 定义词嵌入层:将token索引直接映射为logits# 输入输出维度均为vocab_size(词汇表大小)self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)def forward(self, idx, targets):# 前向传播函数# idx: inputs, 输入序列 (B, T),B=批次数,T=序列长度# targets: 目标序列 (B, T)# (4,8) -> (4,8,65)logits = self.token_embedding_table(idx)  # (B, T, C)# 通过嵌入层获得每个位置的概率分布,C=词汇表大小# (4*8,C), (4*8,) -> (1,)B, T, C = logits.shape  # 解包维度:批次数、序列长度、词表大小logits = logits.view(B*T, C)  # 展平为二维张量 (B*T, C)targets = targets.view(B*T)    # 目标展平为一维张量 (B*T)loss = F.cross_entropy(logits, targets)  # 计算交叉熵损失return logits, loss  # 返回logits(未归一化概率)和损失值# 假设 vocab_size=65(例如52字母+标点)
vocab_size = 65
m = BigramLanguageModel(vocab_size)  # 实例化模型# 假设输入数据(代码中未定义):
# xb: 输入批次 (B=4, T=8),例如 tensor([[1,2,3,...], ...])
# yb: 目标批次 (B=4, T=8)
logits, loss = m(xb, yb)  # 执行前向传播print(logits.shape)  # 输出logits形状:torch.Size([32, 65])
# 解释:32 = B*T = 4*8,65=词表大小(每个位置65种可能)print(loss)  # 输出损失值:tensor(4.8786, grad_fn=<NllLossBackward>)
# 解释:初始随机参数下,损失值约为-ln(1/65)=4.17,实际值因参数初始化略有波动 
3. 代码逻辑分步解释
# 假设输入和目标的形状均为 (B=4, T=8)
 # 输入示例(第一个样本):
 inputs[0] = [24, 43, 58, 5, 57, 1, 46, 43]
 targets[0] = [43, 58, 5, 57, 1, 46, 43, 39]
3.1 Softmax后的概率分布意义

当模型处理输入序列时,每个位置会输出一个长度为vocab_size的logits向量。即
输入: (4,8);
输出:(4,8,65). 65维度向量是每个输入token的下一个token的概率分布。
例如,当输入序列为 [24, 43, 58, 5, 57, 1, 46, 43] 时:
 - 在第1个位置(token=24),模型预测下一个token(对应target=43)的概率分布p[0].shape=(65,),下一个输出是43的概率为p[0][target[0]]=p[0][43];
 - 在第2个位置(token=43),模型预测下一个token(对应target=58)的概率分布p[1].shape=(65),下一个输出是58的概率为p[1][target[1]]=p[1][58];
 - 以此类推,每个位置的logits经过softmax后得到一个概率分布,即每个输入位置,都会预测下一个token概率分布。
 具体来说:
    logits.shape = (4, 8, 65) → softmax后形状不变->p.shape=(4,8,65),但每行的65个值变为概率(和为1)
 这些概率表示模型认为「当前token的下一个token」是词汇表中各token的可能性。
3.2 交叉熵计算步骤

假设logits初始形状为 (4, 8, 65)
 B, T, C = logits.shape  # B=4, T=8, C=65
# 展平logits和targets:
 logits_flat = logits.view(B*T, C)  # 形状 (32, 65)
 targets_flat = targets.view(B*T)    # 形状 (32,)

# 交叉熵计算(PyTorch内部过程):
 # 对logits_flat的每一行(共32行)做softmax,得到概率分布probs (32, 65)
 # 对每个样本i,取probs[i][targets_flat[i]],即真实标签对应的预测概率(此概率是下一个token是targets_flat[i]的概率)
 # 计算负对数损失:loss = -mean(log(probs[i][targets_flat[i]]))(pytorch实现是将targets_flat所谓索引)
 loss = F.cross_entropy(logits_flat, targets_flat)  # 输出标量值
3.3 示例计算

# 以第一个样本的第一个位置为例:
 # 输入token=24,目标token=43
 # 模型输出的logits[0,0]是一个65维向量(这里logits.shape=[4,8,65]),例如:
 logits_example = logits[0,0]  # 形状 (65,)
 probs_example = F.softmax(logits_example, dim=-1)  # 形状 (65,)
 # 假设probs_example[43] = 0.15(模型预测下一个token=43的概率为15%)
 # 则此位置的损失为 -log(0.15) ≈ 1.897 (注意-log(p)是一个x范围在[0,1]之间单调递减函数)
# 最终损失是所有32个位置类似计算的均值。
 # 初始损失约为4.87(接近均匀分布的理论值 -ln(1/65)≈4.17)
4. 测试生成文本
# super simple bigram model
class BigramLanguageModel(nn.Module):def __init__(self, vocab_size):super().__init__()# each token directly reads off the logits for the next token from a lookup tableself.token_embedding_table = nn.Embedding(vocab_size, vocab_size)def forward(self, idx, targets=None):# idx and targets are both (B,T) tensor of integerslogits = self.token_embedding_table(idx) # (B,T,C)if targets is None:loss = Noneelse:B, T, C = logits.shapelogits = logits.view(B*T, C)targets = targets.view(B*T)loss = F.cross_entropy(logits, targets)return logits, lossdef generate(self, idx, max_new_tokens):# idx is (B, T) array of indices in the current contextfor _ in range(max_new_tokens):# get the predictionslogits, loss = self(idx) # 没有输入target时,返回的logits未被展平。  # focus only on the last time steplogits = logits[:, -1, :] # (B,T,C) -> (B, C)# apply softmax to get probabilitiesprobs = F.softmax(logits, dim=-1) # (B, C)# sample from the distributionidx_next = torch.multinomial(probs, num_samples=1) # (B, 1)# append sampled index to the running sequenceidx = torch.cat((idx, idx_next), dim=1) # (B, T+1)return idx 
以下是Bigram模型生成过程的逐步详解,以输入序列[24, 43, 58, 5, 57, 1, 46, 43]为例,说明模型如何从初始输入[24]开始逐步预测下一个词: 
1. 初始输入:[24]
 
-  
输入形状:
idx = [[24]](B=1批次,T=1序列长度)。 -  
前向传播:
-  
通过嵌入层,模型输出
logits形状为(1, 1, 65),表示对当前词24的下一个词的预测分数。 -  
假设
logits[0, 0, 43] = 5.0(词43的logit较高),其他位置logits较低(如logits[0, 0, :] = [..., 5.0, ...])。 
 -  
 -  
概率分布:
-  
对logits应用softmax,得到概率分布
probs。例如:probs = [0.01, ..., 0.8(对应43), 0.01, ...] # 总和为1
 
 -  
 -  
采样:
-  
根据
probs,使用torch.multinomial采样,选中词43的概率最大。 
 -  
 -  
更新输入:
-  
将
43拼接到序列末尾,新输入为idx = [[24, 43]](形状(1, 2))。 
 -  
 
2. 输入:[24, 43]
 
-  
前向传播:
-  
模型处理整个序列,输出
logits形状为(1, 2, 65),对应两个位置的预测:-  
第1个位置(词
24)预测下一个词(已生成43)。 -  
第2个位置(词
43)预测下一个词。 
 -  
 -  
提取最后一个位置的logits:
logits[:, -1, :](形状(1, 65))。 -  
假设
logits[0, -1, 58] = 6.0(词58的logit较高)。 
 -  
 -  
概率分布:
-  
probs = [0.01, ..., 0.85(对应58), 0.01, ...]。 
 -  
 -  
采样:
-  
选中词
58。 
 -  
 -  
更新输入:
-  
新输入为
idx = [[24, 43, 58]](形状(1, 3))。 
 -  
 
3. 输入:[24, 43, 58]
 
-  
前向传播:
-  
logits形状为(1, 3, 65)。 -  
提取最后一个位置(词
58)的logits,假设logits[0, -1, 5] = 4.5。 
 -  
 -  
概率分布:
-  
probs = [0.01, ..., 0.7(对应5), ...]。 
 -  
 -  
采样:
-  
选中词
5。 
 -  
 -  
更新输入:
-  
新输入为
idx = [[24, 43, 58, 5]](形状(1, 4))。 
 -  
 
4. 重复生成直到序列完成
-  
后续步骤:
-  
输入
[24, 43, 58, 5]→ 预测词57。 -  
输入
[24, 43, 58, 5, 57]→ 预测词1。 -  
输入
[24, 43, 58, 5, 57, 1]→ 预测词46。 -  
输入
[24, 43, 58, 5, 57, 1, 46]→ 预测词43。 
 -  
 -  
最终序列:
-  
idx = [[24, 43, 58, 5, 57, 1, 46, 43]]。 
 -  
 
注意:上面输入序列是越来越长的,为何说预测下一个词只跟上一个词有关?如果只跟一个词有关,为何不每次只输入一个词,然后预测下一个词?
虽然理论上可以仅传递最后一个词,但实际实现中传递完整序列的原因(视频作者说的,固定generate函数形式,我这里理解的是代码简洁):
- 代码简洁性:无需在每次生成时截取最后一个词,直接复用统一的前向传播逻辑;
 
实验验证
若修改代码,每次仅传递最后一个词:
def generate(self, idx, max_new_tokens):for _ in range(max_new_tokens):last_token = idx[:, -1:]          # 仅取最后一个词 (B, 1)logits, _ = self(last_token)       # 输出形状 (B, 1, C)probs = F.softmax(logits[:, -1, :], dim=-1)idx_next = torch.multinomial(probs, num_samples=1)idx = torch.cat((idx, idx_next), dim=1)return idx 
相关文章:
简单的二元语言模型bigram实现
内容总结归纳自视频:【珍藏】从头开始用代码构建GPT - 大神Andrej Karpathy 的“神经网络从Zero到Hero 系列”之七_哔哩哔哩_bilibili 项目:https://github.com/karpathy/ng-video-lecture Bigram模型是基于当前Token预测下一个Token的模型。例如&#x…...
【清华大学】实用DeepSeek赋能家庭教育 56页PDF文档完整版
清华大学-56页:实用DeepSeek赋能家庭教育.pdf https://pan.baidu.com/s/1BUweVDeG2M8-t0QaIs3LHQ?pwd1234 提取码: 1234 或 https://pan.quark.cn/s/8a9473493bb0 《实用DeepSeek赋能家庭教育》基于清华大学研究成果,系统阐述了DeepSeek人工智能技…...
黑洞如何阻止光子逃逸
虽然涉及广义相对论,但广义相对论说的是大质量物体对周围空间的影响,而不是说周围空间和空间中的光子之间的关系。也就是说,若讨论光子逃逸问题,则不必限定于大质量的前提,也就是说,若质量周围被扭曲的空间…...
1.4 单元测试与热部署
本次实战实现Spring Boot的单元测试与热部署功能。单元测试方面,通过JUnit和Mockito等工具,结合SpringBootTest注解,可以模拟真实环境对应用组件进行独立测试,验证逻辑正确性,提升代码质量。具体演示了HelloWorld01和H…...
window系统中的start命令详解
start 是 Windows 系统中用于启动新进程或打开新窗口来运行指定程序或命令的命令。以下是对 start 命令参数的详细解释: 基本语法 start ["title"] [/Dpath] [/I] [/MIN] [/MAX] [/SEPARATE | /SHARED] [/LOW | /NORMAL | /HIGH | /REALTIME | /ABOVENO…...
AI编程工具节选
1、文心快码 百度基于文心大模型推出的一款智能编码助手, 官网地址:文心快码(Baidu Comate)更懂你的智能代码助手 2、通义灵码 阿里云出品的一款基于通义大模型的智能编码辅助工具, 官网地址:通义灵码_你的智能编码助手-阿里云 …...
正则表达式,idea,插件anyrule
package lx;import java.util.regex.Pattern;public class lxx {public static void main(String[] args) {//正则表达式//写一个电话号码的正则表达式String regex "1[3-9]\\d{9}";//第一个数字是1,第二个数字是3-9,后面跟着9个数字…...
原生iOS集成react-native (react-native 0.65+)
由于官方文档比较老,很多配置都不能用,集成的时候遇到很多坑,简单的整理一下 时间节点:2021年9月1日 本文主要提供一些配置信息以及错误信息解决方案,具体步骤可以参照官方文档 原版文档:https://reactnative.dev/docs…...
java错题总结
本篇文章用来记录学习javaSE以来的错题 解答:重载要求俩个方法的名字相同,但参数的类型或者个数不同,但是不要求返回类型相同,所以A正确。 重写还需要要求返回类型相同(呈现父子类关系也可以,但是属于特例&…...
【商城实战(10)】解锁商品信息录入与展示的技术密码
【商城实战】专栏重磅来袭!这是一份专为开发者与电商从业者打造的超详细指南。从项目基础搭建,运用 uniapp、Element Plus、SpringBoot 搭建商城框架,到用户、商品、订单等核心模块开发,再到性能优化、安全加固、多端适配…...
2025年主流原型工具测评:墨刀、Axure、Figma、Sketch
2025年主流原型工具测评:墨刀、Axure、Figma、Sketch 要说2025年国内产品经理使用的主流原型设计工具,当然是墨刀、Axure、Figma和Sketch了,但是很多刚入行的产品经理不了解自己适合哪些工具,本文将从核心优势、局限短板、协作能…...
MDM 如何彻底改变医疗设备的远程管理
在现代医疗行业迅速发展的格局中,医院和诊所越来越依赖诸如医疗平板和移动工作站等移动设备。这些设备在提高工作效率和提供卓越的患者护理方面发挥着关键作用。然而,随着它们的广泛使用,也带来了一系列挑战,例如在不同地点确保数…...
OpenCV计算摄影学(18)平滑图像中的纹理区域同时保留边缘信息函数textureFlattening()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 cv::textureFlattening 是 OpenCV 中用于图像处理的一个函数,旨在平滑图像中的纹理区域,同时保留边缘信息。该技术特别适…...
用DeepSeek学Android开发:Android初学者遇到的常见问题有哪些?如何解决?
答案来自 DeepSeek Q: Android初学者遇到的常见问题有哪些?如何解决? A: Android初学者在学习过程中常会遇到以下问题及对应的解决方法,按类别整理如下: 一、开发环境问题 Android Studio安装或配置问题 问题:安装失…...
springboot 集成 MongoDB 基础篇
demo架构: Book Controller: package com.zy.controller;import com.zy.entity.Book; import com.zy.service.MongoDbService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.Get…...
大白话html语义化标签优势与应用场景
大白话html语义化标签优势与应用场景 大白话解释 语义化标签就是那些名字能让人一看就大概知道它是用来做什么的标签。以前我们经常用<div>来做各种布局,但是<div>本身没有什么实际的含义,就像一个没有名字的盒子。而语义化标签就像是有名…...
恶劣天候三维目标检测论文列表整理
恶劣天候三维目标检测论文列表 图摘自Kradar 🏠 介绍 Hi,这是有关恶劣天气下三维目标检测的论文列表。主要是来源于近3年研究过程中认为有意义的文章。希望能为新入门的研究者提供一些帮助。 可能比较简陋,存在一定的遗漏,欢迎…...
conda的环境起的jupyter用不了已经安装的包如何解决
当你在使用Conda环境中的Jupyter Notebook时遇到无法读取某些库或模块的问题,通常是由以下几个原因引起的: 环境未激活:确保你已经在正确的Conda环境中激活了Jupyter Notebook。 库未安装:可能你需要的库没有在当前的Conda环境中…...
蓝桥杯题型
蓝桥杯题型分类 二分 123 传送门 1. 小区间的构成 假设数列的构成是如下形式: 第 1 个区间包含 1 个元素(1)。第 2 个区间包含 2 个元素(1 2)。第 3 个区间包含 3 个元素(1 2 3)。第 4 个区…...
STM32-I2C通信协议
一:I2C通信协议 就是在串口通信上满足四个要求 要求1:删掉一根通信线,防止资源浪费,只能在同一根线上进行发送和接收要求2:需要一个应答机制,没发送一个字节都有一次应答要求3:一根线上能同时…...
OpenLayers 可视化之热力图
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 热力图(Heatmap)又叫热点图,是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...
Java如何权衡是使用无序的数组还是有序的数组
在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...
Java入门学习详细版(一)
大家好,Java 学习是一个系统学习的过程,核心原则就是“理论 实践 坚持”,并且需循序渐进,不可过于着急,本篇文章推出的这份详细入门学习资料将带大家从零基础开始,逐步掌握 Java 的核心概念和编程技能。 …...
浅谈不同二分算法的查找情况
二分算法原理比较简单,但是实际的算法模板却有很多,这一切都源于二分查找问题中的复杂情况和二分算法的边界处理,以下是博主对一些二分算法查找的情况分析。 需要说明的是,以下二分算法都是基于有序序列为升序有序的情况…...
Ubuntu Cursor升级成v1.0
0. 当前版本低 使用当前 Cursor v0.50时 GitHub Copilot Chat 打不开,快捷键也不好用,当看到 Cursor 升级后,还是蛮高兴的 1. 下载 Cursor 下载地址:https://www.cursor.com/cn/downloads 点击下载 Linux (x64) ,…...
Qt 事件处理中 return 的深入解析
Qt 事件处理中 return 的深入解析 在 Qt 事件处理中,return 语句的使用是另一个关键概念,它与 event->accept()/event->ignore() 密切相关但作用不同。让我们详细分析一下它们之间的关系和工作原理。 核心区别:不同层级的事件处理 方…...
tomcat指定使用的jdk版本
说明 有时候需要对tomcat配置指定的jdk版本号,此时,我们可以通过以下方式进行配置 设置方式 找到tomcat的bin目录中的setclasspath.bat。如果是linux系统则是setclasspath.sh set JAVA_HOMEC:\Program Files\Java\jdk8 set JRE_HOMEC:\Program Files…...
Vue 模板语句的数据来源
🧩 Vue 模板语句的数据来源:全方位解析 Vue 模板(<template> 部分)中的表达式、指令绑定(如 v-bind, v-on)和插值({{ }})都在一个特定的作用域内求值。这个作用域由当前 组件…...
【iOS】 Block再学习
iOS Block再学习 文章目录 iOS Block再学习前言Block的三种类型__ NSGlobalBlock____ NSMallocBlock____ NSStackBlock__小结 Block底层分析Block的结构捕获自由变量捕获全局(静态)变量捕获静态变量__block修饰符forwarding指针 Block的copy时机block作为函数返回值将block赋给…...
数据结构:泰勒展开式:霍纳法则(Horner‘s Rule)
目录 🔍 若用递归计算每一项,会发生什么? Horners Rule(霍纳法则) 第一步:我们从最原始的泰勒公式出发 第二步:从形式上重新观察展开式 🌟 第三步:引出霍纳法则&…...
