简单的二元语言模型bigram实现
- 内容总结归纳自视频:【珍藏】从头开始用代码构建GPT - 大神Andrej Karpathy 的“神经网络从Zero到Hero 系列”之七_哔哩哔哩_bilibili
- 项目:https://github.com/karpathy/ng-video-lecture
Bigram模型是基于当前Token预测下一个Token的模型。例如,如果输入序列是`[A, B, C]`,那么模型会根据`A`预测`B`,根据`B`预测`C`,依此类推,实现自回归生成。在生成新Token时,通常只需要最后一个Token的信息,因为每个预测仅依赖于当前Token。
1. 训练batch数据形式
训练数据是:

训练目标是:

2. 定义词嵌入层
nn.Embedding 层输出的是可学习的浮点数,将token索引 (B,T) 直接映射为logits,即输入(4,8),输出 (4,8,65),其中输入每个数字,被映射成logit向量(这些值通过 F.cross_entropy 内部自动进行 softmax 转换为概率分布),比如上面输入tokens有个24被映射成如下。
logits = [1.0, 0.5, -2.0, ..., 3.2] # 共65个浮点数
softmax后得到。
probs = [0.15, 0.12, 0.01, ..., 0.20] # 和为1的概率分布
这样输出的是每个位置的概率分布。
交叉熵函数会自动计算每个位置的概率分布与真实标签之间的损失,并取平均。
简单的大语言模型,基于Bigram的结构,即每个token仅根据前一个token来预测下一个token。具体实现如下。
from torch.nn import functional as F # 导入PyTorch函数模块
torch.manual_seed(1337) # 固定随机种子保证结果可复现class BigramLanguageModel(nn.Module): # 定义Bigram语言模型类def __init__(self, vocab_size):super().__init__() # 继承父类初始化方法# 定义词嵌入层:将token索引直接映射为logits# 输入输出维度均为vocab_size(词汇表大小)self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)def forward(self, idx, targets):# 前向传播函数# idx: inputs, 输入序列 (B, T),B=批次数,T=序列长度# targets: 目标序列 (B, T)# (4,8) -> (4,8,65)logits = self.token_embedding_table(idx) # (B, T, C)# 通过嵌入层获得每个位置的概率分布,C=词汇表大小# (4*8,C), (4*8,) -> (1,)B, T, C = logits.shape # 解包维度:批次数、序列长度、词表大小logits = logits.view(B*T, C) # 展平为二维张量 (B*T, C)targets = targets.view(B*T) # 目标展平为一维张量 (B*T)loss = F.cross_entropy(logits, targets) # 计算交叉熵损失return logits, loss # 返回logits(未归一化概率)和损失值# 假设 vocab_size=65(例如52字母+标点)
vocab_size = 65
m = BigramLanguageModel(vocab_size) # 实例化模型# 假设输入数据(代码中未定义):
# xb: 输入批次 (B=4, T=8),例如 tensor([[1,2,3,...], ...])
# yb: 目标批次 (B=4, T=8)
logits, loss = m(xb, yb) # 执行前向传播print(logits.shape) # 输出logits形状:torch.Size([32, 65])
# 解释:32 = B*T = 4*8,65=词表大小(每个位置65种可能)print(loss) # 输出损失值:tensor(4.8786, grad_fn=<NllLossBackward>)
# 解释:初始随机参数下,损失值约为-ln(1/65)=4.17,实际值因参数初始化略有波动
3. 代码逻辑分步解释
# 假设输入和目标的形状均为 (B=4, T=8)
# 输入示例(第一个样本):
inputs[0] = [24, 43, 58, 5, 57, 1, 46, 43]
targets[0] = [43, 58, 5, 57, 1, 46, 43, 39]
3.1 Softmax后的概率分布意义

当模型处理输入序列时,每个位置会输出一个长度为vocab_size的logits向量。即
输入: (4,8);
输出:(4,8,65). 65维度向量是每个输入token的下一个token的概率分布。
例如,当输入序列为 [24, 43, 58, 5, 57, 1, 46, 43] 时:
- 在第1个位置(token=24),模型预测下一个token(对应target=43)的概率分布p[0].shape=(65,),下一个输出是43的概率为p[0][target[0]]=p[0][43];
- 在第2个位置(token=43),模型预测下一个token(对应target=58)的概率分布p[1].shape=(65),下一个输出是58的概率为p[1][target[1]]=p[1][58];
- 以此类推,每个位置的logits经过softmax后得到一个概率分布,即每个输入位置,都会预测下一个token概率分布。
具体来说:
logits.shape = (4, 8, 65) → softmax后形状不变->p.shape=(4,8,65),但每行的65个值变为概率(和为1)
这些概率表示模型认为「当前token的下一个token」是词汇表中各token的可能性。
3.2 交叉熵计算步骤

假设logits初始形状为 (4, 8, 65)
B, T, C = logits.shape # B=4, T=8, C=65
# 展平logits和targets:
logits_flat = logits.view(B*T, C) # 形状 (32, 65)
targets_flat = targets.view(B*T) # 形状 (32,)

# 交叉熵计算(PyTorch内部过程):
# 对logits_flat的每一行(共32行)做softmax,得到概率分布probs (32, 65)
# 对每个样本i,取probs[i][targets_flat[i]],即真实标签对应的预测概率(此概率是下一个token是targets_flat[i]的概率)
# 计算负对数损失:loss = -mean(log(probs[i][targets_flat[i]]))(pytorch实现是将targets_flat所谓索引)
loss = F.cross_entropy(logits_flat, targets_flat) # 输出标量值
3.3 示例计算

# 以第一个样本的第一个位置为例:
# 输入token=24,目标token=43
# 模型输出的logits[0,0]是一个65维向量(这里logits.shape=[4,8,65]),例如:
logits_example = logits[0,0] # 形状 (65,)
probs_example = F.softmax(logits_example, dim=-1) # 形状 (65,)
# 假设probs_example[43] = 0.15(模型预测下一个token=43的概率为15%)
# 则此位置的损失为 -log(0.15) ≈ 1.897 (注意-log(p)是一个x范围在[0,1]之间单调递减函数)
# 最终损失是所有32个位置类似计算的均值。
# 初始损失约为4.87(接近均匀分布的理论值 -ln(1/65)≈4.17)
4. 测试生成文本
# super simple bigram model
class BigramLanguageModel(nn.Module):def __init__(self, vocab_size):super().__init__()# each token directly reads off the logits for the next token from a lookup tableself.token_embedding_table = nn.Embedding(vocab_size, vocab_size)def forward(self, idx, targets=None):# idx and targets are both (B,T) tensor of integerslogits = self.token_embedding_table(idx) # (B,T,C)if targets is None:loss = Noneelse:B, T, C = logits.shapelogits = logits.view(B*T, C)targets = targets.view(B*T)loss = F.cross_entropy(logits, targets)return logits, lossdef generate(self, idx, max_new_tokens):# idx is (B, T) array of indices in the current contextfor _ in range(max_new_tokens):# get the predictionslogits, loss = self(idx) # 没有输入target时,返回的logits未被展平。 # focus only on the last time steplogits = logits[:, -1, :] # (B,T,C) -> (B, C)# apply softmax to get probabilitiesprobs = F.softmax(logits, dim=-1) # (B, C)# sample from the distributionidx_next = torch.multinomial(probs, num_samples=1) # (B, 1)# append sampled index to the running sequenceidx = torch.cat((idx, idx_next), dim=1) # (B, T+1)return idx
以下是Bigram模型生成过程的逐步详解,以输入序列[24, 43, 58, 5, 57, 1, 46, 43]为例,说明模型如何从初始输入[24]开始逐步预测下一个词:
1. 初始输入:[24]
-
输入形状:
idx = [[24]](B=1批次,T=1序列长度)。 -
前向传播:
-
通过嵌入层,模型输出
logits形状为(1, 1, 65),表示对当前词24的下一个词的预测分数。 -
假设
logits[0, 0, 43] = 5.0(词43的logit较高),其他位置logits较低(如logits[0, 0, :] = [..., 5.0, ...])。
-
-
概率分布:
-
对logits应用softmax,得到概率分布
probs。例如:probs = [0.01, ..., 0.8(对应43), 0.01, ...] # 总和为1
-
-
采样:
-
根据
probs,使用torch.multinomial采样,选中词43的概率最大。
-
-
更新输入:
-
将
43拼接到序列末尾,新输入为idx = [[24, 43]](形状(1, 2))。
-
2. 输入:[24, 43]
-
前向传播:
-
模型处理整个序列,输出
logits形状为(1, 2, 65),对应两个位置的预测:-
第1个位置(词
24)预测下一个词(已生成43)。 -
第2个位置(词
43)预测下一个词。
-
-
提取最后一个位置的logits:
logits[:, -1, :](形状(1, 65))。 -
假设
logits[0, -1, 58] = 6.0(词58的logit较高)。
-
-
概率分布:
-
probs = [0.01, ..., 0.85(对应58), 0.01, ...]。
-
-
采样:
-
选中词
58。
-
-
更新输入:
-
新输入为
idx = [[24, 43, 58]](形状(1, 3))。
-
3. 输入:[24, 43, 58]
-
前向传播:
-
logits形状为(1, 3, 65)。 -
提取最后一个位置(词
58)的logits,假设logits[0, -1, 5] = 4.5。
-
-
概率分布:
-
probs = [0.01, ..., 0.7(对应5), ...]。
-
-
采样:
-
选中词
5。
-
-
更新输入:
-
新输入为
idx = [[24, 43, 58, 5]](形状(1, 4))。
-
4. 重复生成直到序列完成
-
后续步骤:
-
输入
[24, 43, 58, 5]→ 预测词57。 -
输入
[24, 43, 58, 5, 57]→ 预测词1。 -
输入
[24, 43, 58, 5, 57, 1]→ 预测词46。 -
输入
[24, 43, 58, 5, 57, 1, 46]→ 预测词43。
-
-
最终序列:
-
idx = [[24, 43, 58, 5, 57, 1, 46, 43]]。
-
注意:上面输入序列是越来越长的,为何说预测下一个词只跟上一个词有关?如果只跟一个词有关,为何不每次只输入一个词,然后预测下一个词?
虽然理论上可以仅传递最后一个词,但实际实现中传递完整序列的原因(视频作者说的,固定generate函数形式,我这里理解的是代码简洁):
- 代码简洁性:无需在每次生成时截取最后一个词,直接复用统一的前向传播逻辑;
实验验证
若修改代码,每次仅传递最后一个词:
def generate(self, idx, max_new_tokens):for _ in range(max_new_tokens):last_token = idx[:, -1:] # 仅取最后一个词 (B, 1)logits, _ = self(last_token) # 输出形状 (B, 1, C)probs = F.softmax(logits[:, -1, :], dim=-1)idx_next = torch.multinomial(probs, num_samples=1)idx = torch.cat((idx, idx_next), dim=1)return idx
相关文章:
简单的二元语言模型bigram实现
内容总结归纳自视频:【珍藏】从头开始用代码构建GPT - 大神Andrej Karpathy 的“神经网络从Zero到Hero 系列”之七_哔哩哔哩_bilibili 项目:https://github.com/karpathy/ng-video-lecture Bigram模型是基于当前Token预测下一个Token的模型。例如&#x…...
【清华大学】实用DeepSeek赋能家庭教育 56页PDF文档完整版
清华大学-56页:实用DeepSeek赋能家庭教育.pdf https://pan.baidu.com/s/1BUweVDeG2M8-t0QaIs3LHQ?pwd1234 提取码: 1234 或 https://pan.quark.cn/s/8a9473493bb0 《实用DeepSeek赋能家庭教育》基于清华大学研究成果,系统阐述了DeepSeek人工智能技…...
黑洞如何阻止光子逃逸
虽然涉及广义相对论,但广义相对论说的是大质量物体对周围空间的影响,而不是说周围空间和空间中的光子之间的关系。也就是说,若讨论光子逃逸问题,则不必限定于大质量的前提,也就是说,若质量周围被扭曲的空间…...
1.4 单元测试与热部署
本次实战实现Spring Boot的单元测试与热部署功能。单元测试方面,通过JUnit和Mockito等工具,结合SpringBootTest注解,可以模拟真实环境对应用组件进行独立测试,验证逻辑正确性,提升代码质量。具体演示了HelloWorld01和H…...
window系统中的start命令详解
start 是 Windows 系统中用于启动新进程或打开新窗口来运行指定程序或命令的命令。以下是对 start 命令参数的详细解释: 基本语法 start ["title"] [/Dpath] [/I] [/MIN] [/MAX] [/SEPARATE | /SHARED] [/LOW | /NORMAL | /HIGH | /REALTIME | /ABOVENO…...
AI编程工具节选
1、文心快码 百度基于文心大模型推出的一款智能编码助手, 官网地址:文心快码(Baidu Comate)更懂你的智能代码助手 2、通义灵码 阿里云出品的一款基于通义大模型的智能编码辅助工具, 官网地址:通义灵码_你的智能编码助手-阿里云 …...
正则表达式,idea,插件anyrule
package lx;import java.util.regex.Pattern;public class lxx {public static void main(String[] args) {//正则表达式//写一个电话号码的正则表达式String regex "1[3-9]\\d{9}";//第一个数字是1,第二个数字是3-9,后面跟着9个数字…...
原生iOS集成react-native (react-native 0.65+)
由于官方文档比较老,很多配置都不能用,集成的时候遇到很多坑,简单的整理一下 时间节点:2021年9月1日 本文主要提供一些配置信息以及错误信息解决方案,具体步骤可以参照官方文档 原版文档:https://reactnative.dev/docs…...
java错题总结
本篇文章用来记录学习javaSE以来的错题 解答:重载要求俩个方法的名字相同,但参数的类型或者个数不同,但是不要求返回类型相同,所以A正确。 重写还需要要求返回类型相同(呈现父子类关系也可以,但是属于特例&…...
【商城实战(10)】解锁商品信息录入与展示的技术密码
【商城实战】专栏重磅来袭!这是一份专为开发者与电商从业者打造的超详细指南。从项目基础搭建,运用 uniapp、Element Plus、SpringBoot 搭建商城框架,到用户、商品、订单等核心模块开发,再到性能优化、安全加固、多端适配…...
2025年主流原型工具测评:墨刀、Axure、Figma、Sketch
2025年主流原型工具测评:墨刀、Axure、Figma、Sketch 要说2025年国内产品经理使用的主流原型设计工具,当然是墨刀、Axure、Figma和Sketch了,但是很多刚入行的产品经理不了解自己适合哪些工具,本文将从核心优势、局限短板、协作能…...
MDM 如何彻底改变医疗设备的远程管理
在现代医疗行业迅速发展的格局中,医院和诊所越来越依赖诸如医疗平板和移动工作站等移动设备。这些设备在提高工作效率和提供卓越的患者护理方面发挥着关键作用。然而,随着它们的广泛使用,也带来了一系列挑战,例如在不同地点确保数…...
OpenCV计算摄影学(18)平滑图像中的纹理区域同时保留边缘信息函数textureFlattening()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 cv::textureFlattening 是 OpenCV 中用于图像处理的一个函数,旨在平滑图像中的纹理区域,同时保留边缘信息。该技术特别适…...
用DeepSeek学Android开发:Android初学者遇到的常见问题有哪些?如何解决?
答案来自 DeepSeek Q: Android初学者遇到的常见问题有哪些?如何解决? A: Android初学者在学习过程中常会遇到以下问题及对应的解决方法,按类别整理如下: 一、开发环境问题 Android Studio安装或配置问题 问题:安装失…...
springboot 集成 MongoDB 基础篇
demo架构: Book Controller: package com.zy.controller;import com.zy.entity.Book; import com.zy.service.MongoDbService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.Get…...
大白话html语义化标签优势与应用场景
大白话html语义化标签优势与应用场景 大白话解释 语义化标签就是那些名字能让人一看就大概知道它是用来做什么的标签。以前我们经常用<div>来做各种布局,但是<div>本身没有什么实际的含义,就像一个没有名字的盒子。而语义化标签就像是有名…...
恶劣天候三维目标检测论文列表整理
恶劣天候三维目标检测论文列表 图摘自Kradar 🏠 介绍 Hi,这是有关恶劣天气下三维目标检测的论文列表。主要是来源于近3年研究过程中认为有意义的文章。希望能为新入门的研究者提供一些帮助。 可能比较简陋,存在一定的遗漏,欢迎…...
conda的环境起的jupyter用不了已经安装的包如何解决
当你在使用Conda环境中的Jupyter Notebook时遇到无法读取某些库或模块的问题,通常是由以下几个原因引起的: 环境未激活:确保你已经在正确的Conda环境中激活了Jupyter Notebook。 库未安装:可能你需要的库没有在当前的Conda环境中…...
蓝桥杯题型
蓝桥杯题型分类 二分 123 传送门 1. 小区间的构成 假设数列的构成是如下形式: 第 1 个区间包含 1 个元素(1)。第 2 个区间包含 2 个元素(1 2)。第 3 个区间包含 3 个元素(1 2 3)。第 4 个区…...
STM32-I2C通信协议
一:I2C通信协议 就是在串口通信上满足四个要求 要求1:删掉一根通信线,防止资源浪费,只能在同一根线上进行发送和接收要求2:需要一个应答机制,没发送一个字节都有一次应答要求3:一根线上能同时…...
Python|GIF 解析与构建(5):手搓截屏和帧率控制
目录 Python|GIF 解析与构建(5):手搓截屏和帧率控制 一、引言 二、技术实现:手搓截屏模块 2.1 核心原理 2.2 代码解析:ScreenshotData类 2.2.1 截图函数:capture_screen 三、技术实现&…...
网络六边形受到攻击
大家读完觉得有帮助记得关注和点赞!!! 抽象 现代智能交通系统 (ITS) 的一个关键要求是能够以安全、可靠和匿名的方式从互联车辆和移动设备收集地理参考数据。Nexagon 协议建立在 IETF 定位器/ID 分离协议 (…...
《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》
引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...
AI Agent与Agentic AI:原理、应用、挑战与未来展望
文章目录 一、引言二、AI Agent与Agentic AI的兴起2.1 技术契机与生态成熟2.2 Agent的定义与特征2.3 Agent的发展历程 三、AI Agent的核心技术栈解密3.1 感知模块代码示例:使用Python和OpenCV进行图像识别 3.2 认知与决策模块代码示例:使用OpenAI GPT-3进…...
8k长序列建模,蛋白质语言模型Prot42仅利用目标蛋白序列即可生成高亲和力结合剂
蛋白质结合剂(如抗体、抑制肽)在疾病诊断、成像分析及靶向药物递送等关键场景中发挥着不可替代的作用。传统上,高特异性蛋白质结合剂的开发高度依赖噬菌体展示、定向进化等实验技术,但这类方法普遍面临资源消耗巨大、研发周期冗长…...
在rocky linux 9.5上在线安装 docker
前面是指南,后面是日志 sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo sudo dnf install docker-ce docker-ce-cli containerd.io -y docker version sudo systemctl start docker sudo systemctl status docker …...
STM32F4基本定时器使用和原理详解
STM32F4基本定时器使用和原理详解 前言如何确定定时器挂载在哪条时钟线上配置及使用方法参数配置PrescalerCounter ModeCounter Periodauto-reload preloadTrigger Event Selection 中断配置生成的代码及使用方法初始化代码基本定时器触发DCA或者ADC的代码讲解中断代码定时启动…...
Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具
文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...
macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用
文章目录 问题现象问题原因解决办法 问题现象 macOS启动台(Launchpad)多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显,都是Google家的办公全家桶。这些应用并不是通过独立安装的…...
ElasticSearch搜索引擎之倒排索引及其底层算法
文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...
