当前位置: 首页 > news >正文

机器学习深度学习——softmax回归的简洁实现

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——softmax回归从零开始实现
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

继续使用Fashion-MNIST数据集,并保持批量大小为256:

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

softmax回归的简洁实现

  • 初始化模型参数
  • 重新审视softmax的实现
    • 数学推导
    • 交叉熵函数
  • 优化算法
  • 训练

初始化模型参数

softmax的输出层是一个全连接层,因此,为了实现模型,我们只需要在Sequential中添加一个带有10个输出的全连接层。当然这里的Sequential并不是必要的,但是他是深度模型的基础。我们仍旧以均值为0,标准差为0.01来随机初始化权重。

# pytorch不会隐式地调整输入的形状
# 因此在线性层前就定义了展平层flatten,来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)  # 给net每一层跑一次init_weights函数

重新审视softmax的实现

数学推导

在之前的例子里,我们计算了模型的输出,然后将此输出送入交叉熵损失。看似合理,但是指数级计算可能会造成数值的稳定性问题。
回想一下之前的softmax函数:
y ^ j = e x p ( o j ) ∑ k e x p ( o k ) 其中 y ^ j 是预测的概率分布, o j 是未规范化的第 j 个元素 \hat{y}_j=\frac{exp(o_j)}{\sum_kexp(o_k)}\\ 其中\hat{y}_j是预测的概率分布,o_j是未规范化的第j个元素 y^j=kexp(ok)exp(oj)其中y^j是预测的概率分布,oj是未规范化的第j个元素
由于o中的一些数值会非常大,所以可能会让其指数值上溢,使得分子或分母变成inf,最后得到的预测值可能变成的0、inf或者nan。此时我们无法得到一个明确的交叉熵值。
提出解决这个问题的一个技巧:在继续softmax计算之前,先从所有的o中减去max(o),修改softmax函数的构造且不改变其返回值:
y ^ j = e x p ( o j − m a x ( o k ) ) e x p ( m a x ( o k ) ) ∑ k e x p ( o j − m a x ( o k ) ) e x p ( m a x ( o k ) ) \hat{y}_j=\frac{exp(o_j-max(o_k))exp(max(o_k))}{\sum_kexp(o_j-max(o_k))exp(max(o_k))} y^j=kexp(ojmax(ok))exp(max(ok))exp(ojmax(ok))exp(max(ok))
这样操作以后,可能会使得一些分子的exp(o-max(o))有接近0的值,即为下溢。这些值可能会四舍五入为0,这样就会使得预测值为0,那么此时要是取对数以后就会变为-inf。要是这样反向传播几步,我们可能会发现自己屏幕有一堆的nan。
尽管我们需要计算指数函数,但是我们最终会在计算交叉熵损失的时候会取他们的对数。尽管通过将softmax和交叉熵结合在一起,可以避免反向传播过程中可能会困扰我们的数值稳定性问题。如下面的式子:
l o g ( y ^ j ) = l o g ( e x p ( o j − m a x ( o k ) ) ∑ k e x p ( o k − m a x ( o k ) ) ) = l o g ( e x p ( o j − m a x ( o k ) ) ) − l o g ( ∑ k e x p ( o k − m a x ( o k ) ) ) = o j − m a x ( o k ) − l o g ( ∑ k e x p ( o k − m a x ( o k ) ) ) log(\hat{y}_j)=log(\frac{exp(o_j-max(o_k))}{\sum_kexp(o_k-max(o_k))})\\ =log(exp(o_j-max(o_k)))-log(\sum_kexp(o_k-max(o_k)))\\ =o_j-max(o_k)-log(\sum_kexp(o_k-max(o_k))) log(y^j)=log(kexp(okmax(ok))exp(ojmax(ok)))=log(exp(ojmax(ok)))log(kexp(okmax(ok)))=ojmax(ok)log(kexp(okmax(ok)))
通过上式,我们避免了计算单独的exp(o-max(o)),而是直接使用o-max(o)。
因此,我们计算交叉熵函数的时候,传递的不是未规范化的预测o,而不是softmax。
但是我们也希望保留传统的softmax函数,以备我们要评估通过模型输出的概率。

交叉熵函数

在这里介绍一下交叉熵函数,以用于上面推导所需的需求:

torch.nn.CrossEntropyLoss(weight=None,ignore_index=-100,reduction='mean')

交叉熵函数是将LogSoftMax和NLLLoss集成到一个类中,通常用于多分类问题。其参数使用情况:

ignore_index:指定被忽略且对输入梯度没有贡献的目标值。
reduction:string类型的可选项,可在[none,mean,sum]中选。none表示不降维,返回和target一样的形状;mean表示对一个batch的损失求均值;sum表示对一个batch的损失求和。
weight:是一个一维的张量,包含n个元素,分别代表n类的权重,在训练样本不均衡时很有用,默认为None:
(1)当weight=None时,损失函数计算方式为
loss(x,class)=-log(exp(x[class])/Σexp(x[j]))=-x[class]+log(Σexp(x[j])
(2)当weight被指定时,损失函数计算方式为:
loss(x,class)=weight[class]×(-x[class]+log(Σexp(x[j]))

# 在交叉熵损失函数中传递未归一化的预测,并同时计算softmax及其导数
loss = nn.CrossEntropyLoss(reduction='none')

优化算法

# 优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

训练

调用之前定义的训练函数来训练模型:

# 调用之前的训练函数来训练模型
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

在这里插入图片描述

相关文章:

机器学习深度学习——softmax回归的简洁实现

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——softmax回归从零开始实现 📚订阅专栏:机器学习&&深度学习 希望文章对你…...

CPU利用率过高解决思路

文章目录 问题场景问题定位问题解决 本文参考: Linux服务器之CPU过高解决思路_linux cpu温度过高_Jeremy_Lee123的博客-CSDN博客 Java程序员必备:jstack命令解析 - 掘金 (juejin.cn) 重点问题!CPU利用率过高排查思路|原创 (qq.…...

Redis(三)—— Redis基本的事务操作、Redis实现乐观锁

一、Redis基本的事务操作 首先声明: redis的单条命令是保证原子性的(回想一下setnx k1 v1 k5 v5命令如果k1已经存在,那么k5也会设置失败)但是redis的事务不保证原子性!见下面“1.2 某条命令有错怎么办?”…...

SQLI_LABS攻击

目录 Less1 首先来爆字段 联合注入 判断注入点 爆数据库名 爆破表名 information_schema information_schmea.tables group_concat() 爆破列名 information_schema.columns 爆值 SQLMAP Less-2 -4 Less -5 布尔 数据库 表名 字段名 爆破值 SQLMAP Less-6 …...

如何查看 Chrome 网站有没有前端 JavaScript 报错?

您可以按照以下步骤在Chrome中查看网站是否存在前端JavaScript报错: 步骤1:打开Chrome浏览器并访问网站 首先,打开Chrome浏览器并访问您想要检查JavaScript报错的网站。 步骤2:打开开发者工具 在Chrome浏览器中,按…...

JS前端读取本地上传的File文件对象内容(包括Base64、text、JSON、Blob、ArrayBuffer等类型文件)

读取base64图片File file2Base64Image(file, cb) {const reader new FileReader();reader.readAsDataURL(file);reader.onload function (e) {cb && cb(e.target.result);//即为base64结果}; }, 读取text、JSON文件File readText(file, { onloadend } {}) {const re…...

【项目方案】OpenAI流式请求实现方案

文章目录 实现目的效果比对非stream模式stream模式实现方案方案思路总体描述前端方案对比event-source-polyfill代码示例前端实现遇到的问题与解决方法后端参考资料时序图关键代码示例后端实现时遇到的问题与解决方法实现目的 stream是OpenAI API中的一个参数,用于控制请求的…...

华为数通HCIP-IP组播基础

点到点业务:比如FTP,WEB业务,此类业务主要特点是不同的用户有不同的需求,比如用户A需要下载资料A,用户B需要下载资料B。此类业务一般由单播承载,服务器对于不同用户发送不同的点到点数据流。 ospf、isis…...

STM32 SPI学习

SPI 串行外设设备接口(Serial Peripheral Interface),是一种高速的,全双工,同步的通信总线。 SCK时钟信号由主机发出。 SPI接口主要应用在存储芯片。 SPI相关引脚:MOSI(输出数据线&#xff…...

分布式缓存与数据库的一致性记录

用户更新数据库,需要再去更新redis缓存,否则会造成缓存与数据库数据不一致 一致性的两种方法 1). 双写模式 更新完数据库之后,更新redis缓存数据 问题: 因为请求时间的问题,造成缓存数据不是最新的 数据。 原因:A先修…...

vue3的语法

main.js中写发生变化,并不兼容vue2的写法 //vue3 import { createApp } from vue import ./style.css import App from ./App.vuecreateApp(App).mount(#app)//vue2 import Vue from vue import ./style.css import App from ./App.vueconst vm new Vue({render:h…...

【git合并分支自定义提交消息】

开发分支 dev主分支 master 需求 dev分支开发完后合并到master分支自定义提交信息 通过 git merge dev --squash --no-commit此命令会拉取dev分支代码到当前分支,并不会自动提交,可以自己修改提交信息...

AttributeError: module ‘PyQt5.QtGui‘ has no attribute ‘QMainWindow‘

场景描述: 这个问题是使用PyUIC将ui文件变成py文件后遇到的 解决办法: 改动1:把object改成QtWidgets.QMainWindow 改动2:增加__init__函数,函数结构如下: def __init__(self):super(Ui_MainWindow,self).…...

基于Java+SpringBoot+Vue前后端分离电商项目

晚间lucky为友友们送福利啦~🎁 Tips:有需要毕业设计指导的童鞋一定要认真看哦,文末有彩蛋。 一.项目介绍 该电商项目是一个简单、入门级的电商项目,是基于JavaSpringBootVue前后端分离项目。前端采用两套独立的系统分别完成项目…...

Rpc服务消费者(Rpc服务调用者)实现思路

Rpc服务消费者(Rpc服务调用者)实现思路 前面几节说到Rpc消费者主要通过UserServiceRPc_Stub这个protobuf帮我们生成的类来实现,上代码回顾一下 class UserServiceRpc_Stub : public UserServiceRpc {public:UserServiceRpc_Stub(::PROTOBUF…...

FANUC机器人实现2个RO输出信号互锁关联(互补)的具体方法

FANUC机器人实现2个RO输出信号互锁关联(互补)的具体方法 一般情况下,为了方便用户控制工装夹具上的电磁阀等控制工具,FANUC机器人出厂时给我们提供了8个RO输出信号,如下图所示,这8个RO信号可以各自单独使用。 那么,如果为了安全控制,需要将2个RO信号成对的进行安全互锁…...

权威认可|云畅科技再次入选中国信通院「高质量数字化转型产品及服务全景图」

7月27日,由中国信通院主办的2023数字生态发展大会暨中国信通院“铸基计划”年中会议在北京成功召开。 会上,中国信通院重磅发布了「高质量数字化转型产品及服务全景图(2023)」,云畅科技凭借其自研产品「万应低代码」在…...

爬虫小白-如何调试列表页链接与详情链接不一样并三种方式js逆向解决AES-ECB

目录 一、网站分析二、定位监听三、熟悉AES-ECB四、调试分析五、node运行js六、Python执行js 一、网站分析 三年前的案例,我的原始文章网站 ,如图我们直接点击标题进入到详情页,链接会发生跳转,且与我们在详情看到的链接&#xf…...

Ubuntu 离线部署的常见操作

Ubuntu 离线安装的常见操作 **说明:**很多情况下,生产环境都是离线环境,然而开发环境都是互联网的环境,因此部署的过程中需要构建离线安装包; 1. 下载但是不安装 # 例如使用 apt 下载 wireshark 安装包 sudo apt download wireshark # 下载…...

什么是多运行时架构?

服务化演进中的问题 自从数年前微服务的概念被提出,到现在基本成了技术架构的标配。微服务的场景下衍生出了对分布式能力的大量需求:各服务之间需要相互协作和通信,以及共享状态等等,因此就有了各种中间件来为业务服务提供这种分…...

MPNet:旋转机械轻量化故障诊断模型详解python代码复现

目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...

mongodb源码分析session执行handleRequest命令find过程

mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程,并且验证connection是否超过限制ASIOSession和connection是循环接受客户端命令,把数据流转换成Message,状态转变流程是:State::Created 》 St…...

可靠性+灵活性:电力载波技术在楼宇自控中的核心价值

可靠性灵活性:电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中,电力载波技术(PLC)凭借其独特的优势,正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据,无需额外布…...

STM32F4基本定时器使用和原理详解

STM32F4基本定时器使用和原理详解 前言如何确定定时器挂载在哪条时钟线上配置及使用方法参数配置PrescalerCounter ModeCounter Periodauto-reload preloadTrigger Event Selection 中断配置生成的代码及使用方法初始化代码基本定时器触发DCA或者ADC的代码讲解中断代码定时启动…...

江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命

在华东塑料包装行业面临限塑令深度调整的背景下,江苏艾立泰以一场跨国资源接力的创新实践,重新定义了绿色供应链的边界。 跨国回收网络:废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点,将海外废弃包装箱通过标准…...

【配置 YOLOX 用于按目录分类的图片数据集】

现在的图标点选越来越多,如何一步解决,采用 YOLOX 目标检测模式则可以轻松解决 要在 YOLOX 中使用按目录分类的图片数据集(每个目录代表一个类别,目录下是该类别的所有图片),你需要进行以下配置步骤&#x…...

【HTML-16】深入理解HTML中的块元素与行内元素

HTML元素根据其显示特性可以分为两大类:块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

AI书签管理工具开发全记录(十九):嵌入资源处理

1.前言 📝 在上一篇文章中,我们完成了书签的导入导出功能。本篇文章我们研究如何处理嵌入资源,方便后续将资源打包到一个可执行文件中。 2.embed介绍 🎯 Go 1.16 引入了革命性的 embed 包,彻底改变了静态资源管理的…...

使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台

🎯 使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台 📌 项目背景 随着大语言模型(LLM)的广泛应用,开发者常面临多个挑战: 各大模型(OpenAI、Claude、Gemini、Ollama)接口风格不统一;缺乏一个统一平台进行模型调用与测试;本地模型 Ollama 的集成与前…...

2023赣州旅游投资集团

单选题 1.“不登高山,不知天之高也;不临深溪,不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...