【Andrej Karpathy 神经网络从Zero到Hero】--2.语言模型的两种实现方式 (Bigram 和 神经网络)
目录
- 统计 Bigram 语言模型
- 质量评价方法
- 神经网络语言模型
【系列笔记】
【Andrej Karpathy 神经网络从Zero到Hero】–1. 自动微分autograd实践要点
本文主要参考 大神Andrej Karpathy 大模型讲座 | 构建makemore 系列之一:讲解语言建模的明确入门,演示
- 如何利用统计数值构建一个简单的 Bigram 语言模型
- 如何用一个神经网络来复现前面 Bigram 语言模型的结果,以此来展示神经网络相对于传统 n-gram 模型的拓展性。
统计 Bigram 语言模型
首先给定一批数据,每个数据是一个英文名字,例如:
['emma','olivia','ava','isabella','sophia','charlotte','mia','amelia','harper','evelyn']
Bigram语言模型的做法很简单,首先将数据中的英文名字都做成一个个bigram的数据
其中每个格子中是对应的二元组,eg: “rh” ,在所有数据中出现的次数。那么一个自然的想法是对于给定的字母,取其对应的行,将次数归一化转成概率值,然后根据概率分布抽取下一个可能的字母:
g = torch.Generator().manual_seed(2147483647)
P = N.float() # N 即为上述 counts 矩阵
P = P / P.sum(1, keepdims=True) # P是每行归一化后的概率值for i in range(5):out = []ix = 0 ## start符和end符都用 id=0 表示,这里是startwhile True:p = P[ix] # 当前字符为 ix 时,预测下一个字符的概率分布,实质是一个多项分布(即可能抽到的值有多个,eg: 掷色子是六项分布)ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()out.append(itos[ix])if ix == 0: ## 当运行到end符,停止生成breakprint(''.join(out))
输出类似于:
mor.
axx.
minaymoryles.
kondlaisah.
anchshizarie.
质量评价方法
我们还需要方法来评估语言模型的质量,一个直观的想法是:
P ( s 1 s 2 . . . s n ) = P ( s 1 ) P ( s 2 ∣ s 1 ) ⋯ P ( s n ∣ s n − 1 ) P(s_1s_2...s_n) = P(s_1)P(s_2|s_1)\cdots P(s_n|s_{n-1}) P(s1s2...sn)=P(s1)P(s2∣s1)⋯P(sn∣sn−1)
但上述计算方式有一个问题,概率值都是小于1的,当序列的长度比较长时,上述数值会趋于0,计算时容易下溢。因此实践中往往使用 l o g ( P ) log(P) log(P)来代替,为了可以对比不同长度的序列的预测效果,再进一步使用 l o g ( P ) / n log(P)/n log(P)/n 表示一个序列平均的质量。
上述统计 Bigram 模型在训练数据上的平均质量为:
log_likelihood = 0.0
n = 0for w in words: # 所有word里的二元组概率叠加chs = ['.'] + list(w) + ['.']for ch1, ch2 in zip(chs, chs[1:]):ix1 = stoi[ch1]ix2 = stoi[ch2]prob = P[ix1, ix2]logprob = torch.log(prob)log_likelihood += logprobn += 1 # 所有word里的二元组数量之和nll = -log_likelihood
print(f'{nll/n}') ## 值为 2.4764,表示前面做的bigram模型,对现有训练数据的置信度## 这个值越低表示当前模型越认可训练数据的质量,而由于训练数据是我们认为“好”的数据,因此反过来就说明这个模型好
但这里有一个问题是,例如:
log_likelihood = 0.0
n = 0#for w in words:
for w in ["andrejz"]:chs = ['.'] + list(w) + ['.']for ch1, ch2 in zip(chs, chs[1:]):ix1 = stoi[ch1]ix2 = stoi[ch2]prob = P[ix1, ix2]logprob = torch.log(prob)log_likelihood += logprobn += 1print(f'{ch1}{ch2}: {prob:.4f} {logprob:.4f}')print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')
print(f'{nll/n}')
输出是
.a: 0.1377 -1.9829
an: 0.1605 -1.8296
nd: 0.0384 -3.2594
dr: 0.0771 -2.5620
re: 0.1336 -2.0127
ej: 0.0027 -5.9171
jz: 0.0000 -inf
z.: 0.0667 -2.7072
log_likelihood=tensor(-inf)
nll=tensor(inf)
inf
可以发现由于,jz 在计数矩阵 N 中为0,即数据中没有出现过,导致 log(loss) 变成了负无穷,这里为了避免这样的情况,需要做 平滑处理,即 P = N.float() 改成 P = (N+1).float(),这样上述代码输出变成:
.a: 0.1376 -1.9835
an: 0.1604 -1.8302
nd: 0.0384 -3.2594
dr: 0.0770 -2.5646
re: 0.1334 -2.0143
ej: 0.0027 -5.9004
jz: 0.0003 -7.9817
z.: 0.0664 -2.7122
log_likelihood=tensor(-28.2463)
nll=tensor(28.2463)
3.5307815074920654
避免了出现 inf 这种数据溢出问题。
神经网络语言模型
接下来尝试用神经网络的方式构建上述bigram语言模型:
# 构建训练数据
xs, ys = [], [] # 分别是前一个字符和要预测的下一个字符的id
for w in words[:5]:chs = ['.'] + list(w) + ['.']for ch1, ch2 in zip(chs, chs[1:]):ix1 = stoi[ch1]ix2 = stoi[ch2]print(ch1, ch2)xs.append(ix1)ys.append(ix2) xs = torch.tensor(xs)
ys = torch.tensor(ys)
# 输出示例:. e
# e m
# m m
# m a
# a .
# xs: tensor([ 0, 5, 13, 13, 1])
# ys: tensor([ 5, 13, 13, 1, 0])# 随机初始化一个 27*27 的参数矩阵
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True) # 基于正态分布随机初始化
# 前向传播
import torch.nn.functional as F
xenc = F.one_hot(xs, num_classes=27).float() # 将输入数据xs做成one-hot embedding
logits = xenc @ W # 用于模拟统计模型中的统计数值矩阵,由于 W 是基于正态分布采样,logits 并非直接是计数值,可以认为是 log(counts)
## tensor([[-0.5288, -0.5967, -0.7431, ..., 0.5990, -1.5881, 1.1731],
## [-0.3065, -0.1569, -0.8672, ..., 0.0821, 0.0672, -0.3943],
## [ 0.4942, 1.5439, -0.2300, ..., -2.0636, -0.8923, -1.6962],
## ...,
## [-0.1936, -0.2342, 0.5450, ..., -0.0578, 0.7762, 1.9665],
## [-0.4965, -1.5579, 2.6435, ..., 0.9274, 0.3591, -0.3198],
## [ 1.5803, -1.1465, -1.2724, ..., 0.8207, 0.0131, 0.4530]])
counts = logits.exp() # 将 log(counts) 还原成可以看作是 counts 的矩阵
## tensor([[ 0.5893, 0.5507, 0.4756, ..., 1.8203, 0.2043, 3.2321],
## [ 0.7360, 0.8548, 0.4201, ..., 1.0856, 1.0695, 0.6741],
## [ 1.6391, 4.6828, 0.7945, ..., 0.1270, 0.4097, 0.1834],
## ...,
## [ 0.8240, 0.7912, 1.7245, ..., 0.9438, 2.1732, 7.1459],
## [ 0.6086, 0.2106, 14.0621, ..., 2.5279, 1.4320, 0.7263],
## [ 4.8566, 0.3177, 0.2802, ..., 2.2722, 1.0132, 1.5730]])
probs = counts / counts.sum(1, keepdims=True) # 用于模拟统计模型中的概率矩阵,这其实即是 softmax 的实现
loss = -probs[torch.arange(5), ys].log().mean() # loss = log(P)/n, 这其实即是 cross-entropy 的实现
接下来可以通过loss.backward()来更新参数 W:
for k in range(100):# forward passxenc = F.one_hot(xs, num_classes=27).float() logits = xenc @ W # predict log-countscounts = logits.exp()probs = counts / counts.sum(1, keepdims=True) loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean() ## 这里加上了L2正则,防止过拟合print(loss.item())# backward passW.grad = None # 每次反向传播前置为Noneloss.backward()# updateW.data += -50 * W.grad
注意这里 logits = xenc @ W 由于 xenc 是 one-hot 向量,因此这里 logits 相当于是抽出了 W 中的某一行,而结合 bigram 模型中,loss 实际上是在计算实际的 log(P[x_i, y_i]),那么可以认为这里 W 其实是在拟合 bigram 中的计数矩阵 N(不过实际是 logW 在拟合 N)。
另外上述神经网络的 loss 最终也是达到差不多 2.47 的最低 loss。这是合理的,因为从上面的分析可知,这个神经网络是完全在拟合 bigram 计数矩阵的,没有使用更复杂的特征提取方法,因此效果最终也会差不多。
这里 loss 中还加了一个 L2 正则,主要目的是压缩 W,使得它向全 0 靠近,这里的效果非常类似于 bigram 中的平滑手段,想象给一个极大的平滑:P = (N+10000).float()`,那么 P 会趋于一个均匀分布,而 W 全为 0 会导致 counts = logits.exp() 全为 1,即也在拟合一个均匀分布。这里前面的参数 0.01 即是用来调整平滑强度的,如果这个给的太大,那么平滑太大了,就会学成一个均匀分布(当然实际不会希望这样,所以不会给很大)
相关文章:
【Andrej Karpathy 神经网络从Zero到Hero】--2.语言模型的两种实现方式 (Bigram 和 神经网络)
目录 统计 Bigram 语言模型质量评价方法 神经网络语言模型 【系列笔记】 【Andrej Karpathy 神经网络从Zero到Hero】–1. 自动微分autograd实践要点 本文主要参考 大神Andrej Karpathy 大模型讲座 | 构建makemore 系列之一:讲解语言建模的明确入门,演示…...
Android MVC、MVP、MVVM三种架构的介绍和使用。
写在前面:现在随便出去面试Android APP相关的工作,面试官基本上都会提问APP架构相关的问题,用Java、kotlin写APP的话,其实就三种架构MVC、MVP、MVVM,MVC和MVP高度相似,区别不大,MVVM则不同&…...
python使用django搭建图书管理系统
大家好,你们喜欢的梦幻编织者回来了 随着计算机网络和信息技术的不断发展,人类信息交流的方式从根本上发生了改变,计算机技术、信息化技术在各个领域都得到了广泛的应用。图书馆的规模和数量都在迅速增长,馆内藏书也越来越多,管理…...
JavaScript系列06-深入理解 JavaScript 事件系统:从原生事件到 React 合成事件
JavaScript 事件系统是构建交互式 Web 应用的核心。本文从原生 DOM 事件到 React 的合成事件,内容涵盖: JavaScript 事件基础:事件类型、事件注册、事件对象事件传播机制:捕获、目标和冒泡阶段高级事件技术:事件委托、…...
大话机器学习三大门派:监督、无监督与强化学习
以武侠江湖为隐喻,系统阐述了机器学习的三大范式:监督学习(少林派)凭借标注数据精准建模,擅长图像分类等预测任务;无监督学习(逍遥派)通过数据自组织发现隐藏规律,…...
win11编译llama_cpp_python cuda128 RTX30/40/50版本
Geforce 50xx系显卡最低支持cuda128,llama_cpp_python官方源只有cpu版本,没有cuda版本,所以自己基于0.3.5版本源码编译一个RTX 30xx/40xx/50xx版本。 1. 前置条件 1. 访问https://developer.download.nvidia.cn/compute/cuda/12.8.0/local_…...
FY-3D MWRI亮温绘制
1、FY-3D MWRI介绍 风云三号气象卫星(FY-3)是我国自行研制的第二代极轨气象卫星,其有效载荷覆 盖了紫外、可见光、红外、微波等频段,其目标是实现全球全天候、多光谱、三维定量 探测,为中期数值天气预报提供卫星观测数…...
Codeforces1929F Sasha and the Wedding Binary Search Tree
目录 tags中文题面输入格式输出格式样例输入样例输出说明 思路代码 tags 组合数 二叉搜索树 中文题面 定义一棵二叉搜索树满足,点有点权,左儿子的点权 ≤ \leq ≤ 根节点的点权,右儿子的点权 ≥ \geq ≥ 根节点的点权。 现在给定一棵 …...
HBuilder X 使用 TortoiseSVN 设置快捷键方法
HBuilder X 使用 TortoiseSVN 设置快捷键方法 单文件:(上锁,解锁,提交,更新) 安装好 TortoiseSVN ,或者 按图操作: 1,工具栏中 【自定义快捷键】 2,点击 默认的快捷键设置&…...
Java jar包后台运行方式详解
目录 一、打包成 jar 文件二、后台运行 jar 文件三、示例四、总结在 Java 开发中,我们经常需要将应用程序打包成可执行的 jar 文件,并在后台运行。这种方式对于部署长时间运行的任务或需要持续监听事件的应用程序非常重要。本文将详细介绍如何实现 Java jar 包的后台运行,并…...
Refreshtoken 前端 安全 前端安全方面
网络安全 前端不需要过硬的网络安全方面的知识,但是能够了解大多数的网络安全,并且可以进行简单的防御前两三个是需要的 介绍一下常见的安全问题,解决方式,和小的Demo,希望大家喜欢 网络安全汇总 XSSCSRF点击劫持SQL注入OS注入请求劫持DDOS 在我看来,前端可以了解并且防御前…...
Mysql5.7-yum安装和更改mysql数据存放路径-2020年记录
记录下官网里用yum rpm源安装mysql, 1 官网下载rpm https://dev.mysql.com/downloads/repo/yum/ https://dev.mysql.com/doc/refman/5.7/en/linux-installation-yum-repo.html(附官网操作手册) wget https://repo.mysql.com//mysql80-community-release…...
[项目]基于FreeRTOS的STM32四轴飞行器: 七.遥控器按键
基于FreeRTOS的STM32四轴飞行器: 七.遥控器 一.遥控器按键摇杆功能说明二.摇杆和按键的配置三.按键扫描 一.遥控器按键摇杆功能说明 两个手柄四个ADC。 左侧手柄: 前后推为飞控油门,左右推为控制飞机偏航角。 右侧手柄: 控制飞机飞行方向&a…...
Android15使用FFmpeg解码并播放MP4视频完整示例
效果: 1.编译FFmpeg库: 下载FFmpeg-kit的源码并编译生成安装平台库 2.复制生成的FFmpeg库so文件与包含目录到自己的Android下 如果没有prebuiltLibs目录,创建一个,然后复制 包含目录只复制arm64-v8a下...
numpy常用函数详解
在深度神经网络代码中经常用到numpy库的一些函数,很多看过之后很容易忘记,本文对经常使用的函数进行归纳总结。 np.arange arange是numpy一个常用的函数,该函数主要用于创建等差数列。它的使用方法如下所示: numpy.arange([star…...
安装树莓派3B+环境(嵌入式开发)
一、环境配置 1、下载树莓派镜像工具 点击进入下载连接 进入网站,点击下载即可。 2、配置wifi及ssh 将SD卡插入读卡器,再接入电脑,随后打开Raspberry Pi Imager下载工具, 选择Raspberry Pi 3 选择64位的操作系统 选择SD卡 选择…...
深度学习/强化学习调参技巧
深度调优策略 1. 学习率调整 技巧:学习率是最重要的超参数之一。过大可能导致训练不稳定,过小则收敛速度慢。可以使用学习率衰减(Learning Rate Decay)或自适应学习率方法(如Adam、RMSprop)来动态调整学习…...
p5.js:sound(音乐)可视化,动画显示音频高低变化
本文通过4个案例介绍了使用 p5.js 进行音乐可视化的实践,包括将音频振幅转化为图形、生成波形图。 承上一篇:vite:初学 p5.js demo 画圆圈 cd p5-demo copy .\node_modules\p5\lib\p5.min.js . copy .\node_modules\p5\lib\addons\p5.soun…...
Linux下安装elasticsearch(Elasticsearch 7.17.23)
Elasticsearch 是一个分布式的搜索和分析引擎,能够以近乎实时的速度存储、搜索和分析大量数据。它被广泛应用于日志分析、全文搜索、应用程序监控等场景。 本文将带你一步步在 Linux 系统上安装 Elasticsearch 7.17.23 版本,并完成基本的配置࿰…...
plt和cv2有不同的图像表示方式和颜色通道顺序
在处理图像时,matplotlib.pyplot (简称 plt) 和 OpenCV (简称 cv2) 有不同的图像表示方式和颜色通道顺序。了解这些区别对于正确处理和显示图像非常重要。 1. 图像形状和颜色通道顺序 matplotlib.pyplot (plt) 形状:plt 通常使用 (height, width, cha…...
【The Rap of China】2018
中国新说唱第一季,2018 2018年4月13日,该节目通过官方微博宣布,其第二季将更名为《中国新说唱》。 《中国新说唱2018》由张震岳、MC Hotdog、潘玮柏、邓紫棋、WYF 担任明星制作人; 艾热获得冠军、那吾克热玉素甫江获得亚军、ICE…...
通义万相2.1开源版本地化部署攻略,生成视频再填利器
2025 年 2 月 25 日晚上 11:00 通义万相 2.1 开源发布,前两周太忙没空搞它,这个周末,也来本地化部署一个,体验生成效果如何,总的来说,它在国内文生视频、图生视频的行列处于领先位置,…...
YOLOv10改进之MHAF(多分支辅助特征金字塔)
YOLOv10架构 YOLOv10的架构主要由 主干网络、特征金字塔和预测头 三部分组成。主干网络采用改进的Darknet结构,增强特征提取能力。特征金字塔模块使用多尺度特征融合技术,提高对不同大小目标的检测效果。预测头则负责生成最终的检测结果。这种结构设计使得YOLOv10在保持高效…...
好玩的谷歌浏览器插件-自定义谷歌浏览器光标皮肤插件-Chrome 的自定义光标
周末没有啥事 看到了一个非常有意思的插件 就是 在使用谷歌浏览器的时候,可以把鼠标的默认样式换一个皮肤。就像下面的这种样子。 实际谷歌浏览器插件开发对于有前端编程基础的小伙伴 还是比较容易的,实际也是写 html css js 。 所以这个插件使用的技术…...
svn删除所有隐藏.svn文件,文件夹脱离svn控制
新建一个文件,取名remove-svn-folders.reg,输入如下内容: Windows Registry Editor Version 5.00 [HKEY_LOCAL_MACHINE\SOFTWARE\Classes\Folder\shell\DeleteSVN] "Delete SVN Folders" [HKEY_LOCAL_MACHINE\SOFTWARE\Class…...
六十天前端强化训练之第十二天之闭包深度解析
欢迎来到编程星辰海的博客讲解 目录 第一章:闭包的底层运行机制 1.1 词法环境(Lexical Environment)的构成JavaScript 引擎通过三个关键组件管理作用域: 1.2 作用域链的创建过程当函数被定义时: 1.3 闭包变量的生命…...
DeepSeek R1-32B医疗大模型的完整微调实战分析(全码版)
DeepSeek R1-32B微调实战指南 ├── 1. 环境准备 │ ├── 1.1 硬件配置 │ │ ├─ 全参数微调:4*A100 80GB │ │ └─ LoRA微调:单卡24GB │ ├── 1.2 软件依赖 │ │ ├─ PyTorch 2.1.2+CUDA │ │ └─ Unsloth/ColossalAI │ └── 1.3 模…...
10.2 继承与多态
文章目录 继承多态 继承 继承的作用是代码复用。派生类自动获得基类的除私有成员外的一切。基类描述一般特性,派生类提供更丰富的属性和行为。在构造派生类时,其基类构造函数先被调用,然后是派生类构造函数。在析构时顺序刚好相反。 // 基类…...
[网络爬虫] 动态网页抓取 — Selenium 元素定位
🌟想系统化学习爬虫技术?看看这个:[数据抓取] Python 网络爬虫 - 学习手册-CSDN博客 在使用 Selenium 时,往往需要先定位到指定元素,然后再执行相应的操作。例如,再向文本输入框中输入文字之前,…...
静态网页的爬虫(以电影天堂为例)
一、电影天堂的网址(url) 电影天堂_免费电影_迅雷电影下载_电影天堂网最好的迅雷电影下载网,分享最新电影,高清电影、综艺、动漫、电视剧等下载!https://dydytt.net/index.htm 我们要爬取这个页面上的内容 二、代码…...
