机器学习深度学习——softmax回归的简洁实现
👨🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——softmax回归从零开始实现
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助
继续使用Fashion-MNIST数据集,并保持批量大小为256:
import torch
from torch import nn
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
softmax回归的简洁实现
- 初始化模型参数
- 重新审视softmax的实现
- 数学推导
- 交叉熵函数
- 优化算法
- 训练
初始化模型参数
softmax的输出层是一个全连接层,因此,为了实现模型,我们只需要在Sequential中添加一个带有10个输出的全连接层。当然这里的Sequential并不是必要的,但是他是深度模型的基础。我们仍旧以均值为0,标准差为0.01来随机初始化权重。
# pytorch不会隐式地调整输入的形状
# 因此在线性层前就定义了展平层flatten,来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights) # 给net每一层跑一次init_weights函数
重新审视softmax的实现
数学推导
在之前的例子里,我们计算了模型的输出,然后将此输出送入交叉熵损失。看似合理,但是指数级计算可能会造成数值的稳定性问题。
回想一下之前的softmax函数:
y ^ j = e x p ( o j ) ∑ k e x p ( o k ) 其中 y ^ j 是预测的概率分布, o j 是未规范化的第 j 个元素 \hat{y}_j=\frac{exp(o_j)}{\sum_kexp(o_k)}\\ 其中\hat{y}_j是预测的概率分布,o_j是未规范化的第j个元素 y^j=∑kexp(ok)exp(oj)其中y^j是预测的概率分布,oj是未规范化的第j个元素
由于o中的一些数值会非常大,所以可能会让其指数值上溢,使得分子或分母变成inf,最后得到的预测值可能变成的0、inf或者nan。此时我们无法得到一个明确的交叉熵值。
提出解决这个问题的一个技巧:在继续softmax计算之前,先从所有的o中减去max(o),修改softmax函数的构造且不改变其返回值:
y ^ j = e x p ( o j − m a x ( o k ) ) e x p ( m a x ( o k ) ) ∑ k e x p ( o j − m a x ( o k ) ) e x p ( m a x ( o k ) ) \hat{y}_j=\frac{exp(o_j-max(o_k))exp(max(o_k))}{\sum_kexp(o_j-max(o_k))exp(max(o_k))} y^j=∑kexp(oj−max(ok))exp(max(ok))exp(oj−max(ok))exp(max(ok))
这样操作以后,可能会使得一些分子的exp(o-max(o))有接近0的值,即为下溢。这些值可能会四舍五入为0,这样就会使得预测值为0,那么此时要是取对数以后就会变为-inf。要是这样反向传播几步,我们可能会发现自己屏幕有一堆的nan。
尽管我们需要计算指数函数,但是我们最终会在计算交叉熵损失的时候会取他们的对数。尽管通过将softmax和交叉熵结合在一起,可以避免反向传播过程中可能会困扰我们的数值稳定性问题。如下面的式子:
l o g ( y ^ j ) = l o g ( e x p ( o j − m a x ( o k ) ) ∑ k e x p ( o k − m a x ( o k ) ) ) = l o g ( e x p ( o j − m a x ( o k ) ) ) − l o g ( ∑ k e x p ( o k − m a x ( o k ) ) ) = o j − m a x ( o k ) − l o g ( ∑ k e x p ( o k − m a x ( o k ) ) ) log(\hat{y}_j)=log(\frac{exp(o_j-max(o_k))}{\sum_kexp(o_k-max(o_k))})\\ =log(exp(o_j-max(o_k)))-log(\sum_kexp(o_k-max(o_k)))\\ =o_j-max(o_k)-log(\sum_kexp(o_k-max(o_k))) log(y^j)=log(∑kexp(ok−max(ok))exp(oj−max(ok)))=log(exp(oj−max(ok)))−log(k∑exp(ok−max(ok)))=oj−max(ok)−log(k∑exp(ok−max(ok)))
通过上式,我们避免了计算单独的exp(o-max(o)),而是直接使用o-max(o)。
因此,我们计算交叉熵函数的时候,传递的不是未规范化的预测o,而不是softmax。
但是我们也希望保留传统的softmax函数,以备我们要评估通过模型输出的概率。
交叉熵函数
在这里介绍一下交叉熵函数,以用于上面推导所需的需求:
torch.nn.CrossEntropyLoss(weight=None,ignore_index=-100,reduction='mean')
交叉熵函数是将LogSoftMax和NLLLoss集成到一个类中,通常用于多分类问题。其参数使用情况:
ignore_index:指定被忽略且对输入梯度没有贡献的目标值。
reduction:string类型的可选项,可在[none,mean,sum]中选。none表示不降维,返回和target一样的形状;mean表示对一个batch的损失求均值;sum表示对一个batch的损失求和。
weight:是一个一维的张量,包含n个元素,分别代表n类的权重,在训练样本不均衡时很有用,默认为None:
(1)当weight=None时,损失函数计算方式为
loss(x,class)=-log(exp(x[class])/Σexp(x[j]))=-x[class]+log(Σexp(x[j])
(2)当weight被指定时,损失函数计算方式为:
loss(x,class)=weight[class]×(-x[class]+log(Σexp(x[j]))
# 在交叉熵损失函数中传递未归一化的预测,并同时计算softmax及其导数
loss = nn.CrossEntropyLoss(reduction='none')
优化算法
# 优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
训练
调用之前定义的训练函数来训练模型:
# 调用之前的训练函数来训练模型
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()
相关文章:

机器学习深度学习——softmax回归的简洁实现
👨🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——softmax回归从零开始实现 📚订阅专栏:机器学习&&深度学习 希望文章对你…...

CPU利用率过高解决思路
文章目录 问题场景问题定位问题解决 本文参考: Linux服务器之CPU过高解决思路_linux cpu温度过高_Jeremy_Lee123的博客-CSDN博客 Java程序员必备:jstack命令解析 - 掘金 (juejin.cn) 重点问题!CPU利用率过高排查思路|原创 (qq.…...

Redis(三)—— Redis基本的事务操作、Redis实现乐观锁
一、Redis基本的事务操作 首先声明: redis的单条命令是保证原子性的(回想一下setnx k1 v1 k5 v5命令如果k1已经存在,那么k5也会设置失败)但是redis的事务不保证原子性!见下面“1.2 某条命令有错怎么办?”…...

SQLI_LABS攻击
目录 Less1 首先来爆字段 联合注入 判断注入点 爆数据库名 爆破表名 information_schema information_schmea.tables group_concat() 爆破列名 information_schema.columns 爆值 SQLMAP Less-2 -4 Less -5 布尔 数据库 表名 字段名 爆破值 SQLMAP Less-6 …...
如何查看 Chrome 网站有没有前端 JavaScript 报错?
您可以按照以下步骤在Chrome中查看网站是否存在前端JavaScript报错: 步骤1:打开Chrome浏览器并访问网站 首先,打开Chrome浏览器并访问您想要检查JavaScript报错的网站。 步骤2:打开开发者工具 在Chrome浏览器中,按…...

JS前端读取本地上传的File文件对象内容(包括Base64、text、JSON、Blob、ArrayBuffer等类型文件)
读取base64图片File file2Base64Image(file, cb) {const reader new FileReader();reader.readAsDataURL(file);reader.onload function (e) {cb && cb(e.target.result);//即为base64结果}; }, 读取text、JSON文件File readText(file, { onloadend } {}) {const re…...

【项目方案】OpenAI流式请求实现方案
文章目录 实现目的效果比对非stream模式stream模式实现方案方案思路总体描述前端方案对比event-source-polyfill代码示例前端实现遇到的问题与解决方法后端参考资料时序图关键代码示例后端实现时遇到的问题与解决方法实现目的 stream是OpenAI API中的一个参数,用于控制请求的…...

华为数通HCIP-IP组播基础
点到点业务:比如FTP,WEB业务,此类业务主要特点是不同的用户有不同的需求,比如用户A需要下载资料A,用户B需要下载资料B。此类业务一般由单播承载,服务器对于不同用户发送不同的点到点数据流。 ospf、isis…...

STM32 SPI学习
SPI 串行外设设备接口(Serial Peripheral Interface),是一种高速的,全双工,同步的通信总线。 SCK时钟信号由主机发出。 SPI接口主要应用在存储芯片。 SPI相关引脚:MOSI(输出数据线ÿ…...
分布式缓存与数据库的一致性记录
用户更新数据库,需要再去更新redis缓存,否则会造成缓存与数据库数据不一致 一致性的两种方法 1). 双写模式 更新完数据库之后,更新redis缓存数据 问题: 因为请求时间的问题,造成缓存数据不是最新的 数据。 原因:A先修…...
vue3的语法
main.js中写发生变化,并不兼容vue2的写法 //vue3 import { createApp } from vue import ./style.css import App from ./App.vuecreateApp(App).mount(#app)//vue2 import Vue from vue import ./style.css import App from ./App.vueconst vm new Vue({render:h…...
【git合并分支自定义提交消息】
开发分支 dev主分支 master 需求 dev分支开发完后合并到master分支自定义提交信息 通过 git merge dev --squash --no-commit此命令会拉取dev分支代码到当前分支,并不会自动提交,可以自己修改提交信息...

AttributeError: module ‘PyQt5.QtGui‘ has no attribute ‘QMainWindow‘
场景描述: 这个问题是使用PyUIC将ui文件变成py文件后遇到的 解决办法: 改动1:把object改成QtWidgets.QMainWindow 改动2:增加__init__函数,函数结构如下: def __init__(self):super(Ui_MainWindow,self).…...

基于Java+SpringBoot+Vue前后端分离电商项目
晚间lucky为友友们送福利啦~🎁 Tips:有需要毕业设计指导的童鞋一定要认真看哦,文末有彩蛋。 一.项目介绍 该电商项目是一个简单、入门级的电商项目,是基于JavaSpringBootVue前后端分离项目。前端采用两套独立的系统分别完成项目…...
Rpc服务消费者(Rpc服务调用者)实现思路
Rpc服务消费者(Rpc服务调用者)实现思路 前面几节说到Rpc消费者主要通过UserServiceRPc_Stub这个protobuf帮我们生成的类来实现,上代码回顾一下 class UserServiceRpc_Stub : public UserServiceRpc {public:UserServiceRpc_Stub(::PROTOBUF…...

FANUC机器人实现2个RO输出信号互锁关联(互补)的具体方法
FANUC机器人实现2个RO输出信号互锁关联(互补)的具体方法 一般情况下,为了方便用户控制工装夹具上的电磁阀等控制工具,FANUC机器人出厂时给我们提供了8个RO输出信号,如下图所示,这8个RO信号可以各自单独使用。 那么,如果为了安全控制,需要将2个RO信号成对的进行安全互锁…...

权威认可|云畅科技再次入选中国信通院「高质量数字化转型产品及服务全景图」
7月27日,由中国信通院主办的2023数字生态发展大会暨中国信通院“铸基计划”年中会议在北京成功召开。 会上,中国信通院重磅发布了「高质量数字化转型产品及服务全景图(2023)」,云畅科技凭借其自研产品「万应低代码」在…...

爬虫小白-如何调试列表页链接与详情链接不一样并三种方式js逆向解决AES-ECB
目录 一、网站分析二、定位监听三、熟悉AES-ECB四、调试分析五、node运行js六、Python执行js 一、网站分析 三年前的案例,我的原始文章网站 ,如图我们直接点击标题进入到详情页,链接会发生跳转,且与我们在详情看到的链接…...

Ubuntu 离线部署的常见操作
Ubuntu 离线安装的常见操作 **说明:**很多情况下,生产环境都是离线环境,然而开发环境都是互联网的环境,因此部署的过程中需要构建离线安装包; 1. 下载但是不安装 # 例如使用 apt 下载 wireshark 安装包 sudo apt download wireshark # 下载…...

什么是多运行时架构?
服务化演进中的问题 自从数年前微服务的概念被提出,到现在基本成了技术架构的标配。微服务的场景下衍生出了对分布式能力的大量需求:各服务之间需要相互协作和通信,以及共享状态等等,因此就有了各种中间件来为业务服务提供这种分…...

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型
摘要 拍照搜题系统采用“三层管道(多模态 OCR → 语义检索 → 答案渲染)、两级检索(倒排 BM25 向量 HNSW)并以大语言模型兜底”的整体框架: 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后,分别用…...
【Linux】shell脚本忽略错误继续执行
在 shell 脚本中,可以使用 set -e 命令来设置脚本在遇到错误时退出执行。如果你希望脚本忽略错误并继续执行,可以在脚本开头添加 set e 命令来取消该设置。 举例1 #!/bin/bash# 取消 set -e 的设置 set e# 执行命令,并忽略错误 rm somefile…...

Docker 运行 Kafka 带 SASL 认证教程
Docker 运行 Kafka 带 SASL 认证教程 Docker 运行 Kafka 带 SASL 认证教程一、说明二、环境准备三、编写 Docker Compose 和 jaas文件docker-compose.yml代码说明:server_jaas.conf 四、启动服务五、验证服务六、连接kafka服务七、总结 Docker 运行 Kafka 带 SASL 认…...

【网络安全产品大调研系列】2. 体验漏洞扫描
前言 2023 年漏洞扫描服务市场规模预计为 3.06(十亿美元)。漏洞扫描服务市场行业预计将从 2024 年的 3.48(十亿美元)增长到 2032 年的 9.54(十亿美元)。预测期内漏洞扫描服务市场 CAGR(增长率&…...
镜像里切换为普通用户
如果你登录远程虚拟机默认就是 root 用户,但你不希望用 root 权限运行 ns-3(这是对的,ns3 工具会拒绝 root),你可以按以下方法创建一个 非 root 用户账号 并切换到它运行 ns-3。 一次性解决方案:创建非 roo…...

基于Docker Compose部署Java微服务项目
一. 创建根项目 根项目(父项目)主要用于依赖管理 一些需要注意的点: 打包方式需要为 pom<modules>里需要注册子模块不要引入maven的打包插件,否则打包时会出问题 <?xml version"1.0" encoding"UTF-8…...

EtherNet/IP转DeviceNet协议网关详解
一,设备主要功能 疆鸿智能JH-DVN-EIP本产品是自主研发的一款EtherNet/IP从站功能的通讯网关。该产品主要功能是连接DeviceNet总线和EtherNet/IP网络,本网关连接到EtherNet/IP总线中做为从站使用,连接到DeviceNet总线中做为从站使用。 在自动…...
css3笔记 (1) 自用
outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size:0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格ÿ…...

html-<abbr> 缩写或首字母缩略词
定义与作用 <abbr> 标签用于表示缩写或首字母缩略词,它可以帮助用户更好地理解缩写的含义,尤其是对于那些不熟悉该缩写的用户。 title 属性的内容提供了缩写的详细说明。当用户将鼠标悬停在缩写上时,会显示一个提示框。 示例&#x…...
【无标题】路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论
路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论 一、传统路径模型的根本缺陷 在经典正方形路径问题中(图1): mermaid graph LR A((A)) --- B((B)) B --- C((C)) C --- D((D)) D --- A A -.- C[无直接路径] B -…...