当前位置: 首页 > news >正文

线性神经网络——softmax 回归随笔【深度学习】【PyTorch】【d2l】

文章目录

    • 3.2、softmax 回归
      • 3.2.1、softmax运算
      • 3.2.2、交叉熵损失函数
      • 3.2.3、PyTorch 从零实现 softmax 回归
      • 3.2.4、简单实现 softmax 回归

在这里插入图片描述

3.2、softmax 回归

3.2.1、softmax运算

在这里插入图片描述

softmax 函数是一种常用的激活函数,用于将实数向量转换为概率分布向量。它在多类别分类问题中起到重要的作用,并与交叉熵损失函数结合使用。

y ^ = s o f t m a x ( o ) 其中     y ^ i = e x p ( o j ) ∑ k e x p ( o k ) \hat{y} = softmax(o) \ \ \ \ \ 其中\ \ \ \ \hat{y}_i = \frac{exp(o_j)}{\sum_{k}exp(o_k)} y^=softmax(o)     其中    y^i=kexp(ok)exp(oj)

其中,O为小批量的未规范化的预测, Y ^ \hat{Y} Y^w为输出概率,是一个正确的概率分布【 ∑ y i = 1 \sum{y_i} =1 yi=1

3.2.2、交叉熵损失函数

通过测量给定模型编码的比特位,来衡量两概率分布之间的差异,是分类问题中常用的 loss 函数。

H ( P , Q ) = − Σ P ( x ) ∗ l o g ( Q ( x ) ) H(P, Q) = -Σ P(x) * log(Q(x)) H(P,Q)=ΣP(x)log(Q(x))

真实概率分布是从哪里得知的?

真实标签的概率分布是由数据集中的标签信息提供的,通常使用单热编码表示。

softmax() 如何与交叉熵函数搭配的?

softmax 函数与交叉熵损失函数常用于多分类任务中。softmax 函数用于将模型输出转化为概率分布形式,交叉熵损失函数用于衡量模型输出概率分布与真实标签的差异,并通过优化算法来最小化损失函数,从而训练出更准确的分类模型。

3.2.3、PyTorch 从零实现 softmax 回归

(非完整代码)

#在 Notebook 中内嵌绘图
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l#,将图形显示格式设置为 SVG 格式,以在 Notebook 中以矢量图形的形式显示图像。这有助于提高图像的清晰度和可缩放性。
d2l .use_svg_display()

在线下载数据集 Fashion-MNIST

#将图像数据转换为张量形式
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data",train=False,transform =trans,download=True)len(mnist_train),len(mnist_test)

绘图(略)

读取小批量数据集

batch_size = 256def get_dataloader_workers():"""使用4进程读取"""return 4train_iter = data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers())
timer = d2l.Timer()
for X,y in train_iter:continue
print(f'{timer.stop():.2f}sec')

定义softmax操作

def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp / partition  # 这里应用了广播机制

定义损失函数

def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])cross_entropy(y_hat, y)

分类精度

def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())

评估

def evaluate_accuracy(net, data_iter):  #@save"""计算在指定数据集上模型的精度"""if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]
class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]

3.2.4、简单实现 softmax 回归

导入前面已下载数据集 Fashion-MNIST

import torch 
from torch import nn
from d2l import torch as d2lbatch_size =256
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)

初始化模型

#nn.Flatten() 层的作用是将输入数据展平,将二维输入(如图像)转换为一维向量。因为线性层(nn.Linear)通常期望接收一维输入。
#nn.Linear(784,10) 将输入特征从 784 维降低到 10 维,用于图像分类问题中的 10 个类别的预测   784维向量->10维向量
net = nn.Sequential(nn.Flatten(),nn.Linear(784,10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight,std=0.01)net.apply(init_weights);
#计算交叉熵损失函数,用于衡量模型预测与真实标签之间的差异。参数 reduction 控制了损失的计算方式。
#reduction='none' 表示不进行损失的降维或聚合操作,即返回每个样本的独立损失值。
loss = nn.CrossEntropyLoss(reduction='none')

优化算法

trainer = torch.optim.SGD(net.parameters(),lr=0.1)

训练

num_epochs = 10
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)

相关文章:

线性神经网络——softmax 回归随笔【深度学习】【PyTorch】【d2l】

文章目录 3.2、softmax 回归3.2.1、softmax运算3.2.2、交叉熵损失函数3.2.3、PyTorch 从零实现 softmax 回归3.2.4、简单实现 softmax 回归 3.2、softmax 回归 3.2.1、softmax运算 softmax 函数是一种常用的激活函数,用于将实数向量转换为概率分布向量。它在多类别…...

【Nodejs】Node.js开发环境安装

1.版本介绍 在命令窗口中输入 node -v 可以查看版本 0.x 完全不技术 ES64.x 部分支持 ES6 特性5.x 部分支持ES6特性(比4.x多些),属于过渡产品,现在来说应该没有什么理由去用这个了6.x 支持98%的 ES6 特性8.x 支持 ES6 特性 2.No…...

梅尔频谱(Mel spectrum)简介及Python实现

梅尔频谱(Mel spectrum)简介及Python实现 1. 梅尔频谱(Mel spectrum)简介2. Python可视化测试3.频谱可视化3.1 Mel 频谱可视化3.2 STFT spectrum参考文献资料1. 梅尔频谱(Mel spectrum)简介 在信号处理上,声信号(噪声信号)是一种重要的传感监测手段。对于语音分类任务…...

【数据结构】实验六:队列

实验六 队列 一、实验目的与要求 1)熟悉C/C语言(或其他编程语言)的集成开发环境; 2)通过本实验加深对队列的理解,熟悉基本操作; 3) 结合具体的问题分析算法时间复杂度。 二、…...

【Linux线程】第一章||理解线程概念+创建一个线程(附代码加讲解)

线程概念 🌵什么是线程🌲线程和进程的关系🎄线程有以下特点:🌳 线程的优点🌴 线程的缺点🌱线程异常🌿线程用途 ☘️手动创建一个进程🍀运行 🌵什么是线程 在L…...

Android进阶之微信扫码登录

遇到新需求要搭建微信扫码登录功能,这篇文章是随着我的编码过程一并写的,希望能够帮助有需求的人和以后再次用到此功能的自己。 首先想到的就是百度各种文章,当然去开发者平台申请AppID和密钥是必不可少的,等注册好发现需要创建应用以及审核(要官网,流程图及其他信息),想着先写…...

macOS Monterey 12.6.8 (21G725) Boot ISO 原版可引导镜像

macOS Monterey 12.6.8 (21G725) Boot ISO 原版可引导镜像 本站下载的 macOS 软件包,既可以拖拽到 Applications(应用程序)下直接安装,也可以制作启动 U 盘安装,或者在虚拟机中启动安装。另外也支持在 Windows 和 Lin…...

Unity自定义后处理——用偏导数求图片颜色边缘

大家好,我是阿赵。   继续介绍屏幕后处理效果的做法。这次介绍一下用偏导数求图形边缘的技术。 一、原理介绍 先来看例子吧。   这个例子看起来好像是要给模型描边。之前其实也介绍过很多描边的方法,比如沿着法线方向放大模型,或者用Ndo…...

本地Git仓库和GitHub仓库SSH传输

SSH创建命令解释 ssh-keygen 用于创建密钥的程序 -m PEM 将密钥的格式设为 PEM -t rsa 要创建的密钥类型,本例中为 RSA 格式 -b 4096 密钥的位数,本例中为 4096 -C “azureusermyserver” 追加到公钥文件末尾以便于识别的注释。 通常以电子邮件地址…...

【C++11】——右值引用、移动语义

目录 1. 基本概念 1.1 左值与左值引用 1.2 右值和右值引用 1.3 左值引用与右值引用 2. 右值引用实用场景和意义 2.1 左值引用的使用场景 2.2 左值引用的短板 2.3 右值引用和移动语义 2.3.1 移动构造 2.3.2 移动赋值 2.3.3 编译器做的优化 2.3.4 总结 2.4 右值引用…...

消息服务概述

消息服务的作用: 在多数应用尤其是分布式系统中,消息服务是不可或缺的重要部分,它使用起来比较简单,同时解决了不少难题,例如异步处理、应用解耦、流量削锋、分布式事务管理等,使用消息服务可以实现一个高…...

【Spring Boot】Web开发 — 数据验证

Web开发 — 数据验证 对于应用系统而言,任何客户端传入的数据都不是绝对安全有效的,这就要求我们在服务端接收到数据时也对数据的有效性进行验证,以确保传入的数据安全正确。接下来介绍Spring Boot是如何实现数据验证的。 1.Hibernate Vali…...

技术分享 | App常见bug解析

功能Bug 内容显示错误 前端页面展示的内容有误。 这种错误的产生有两种可能 1、前端代码写的文案错误 2、接口返回值错误 功能错误 功能错误是在测试过程中最常见的类型之一,也就是产品的功能没有实现。比如图中的公众号登录不成功的问题。 界面展示错乱 产…...

树莓派Pico|RP2040|使用SWD进行调试|构建 “Hello World“ debug版本

文章目录 使用SWD进行调试构建 "Hello World" debug版本安装 GDB使用 GDB 和 OpenOCD 来 debug Hello World TIP重要提示 使用SWD进行调试 基于rp2040的板上的SWD端口重置,加载和运行代码,如树莓派Pico可用于交互式调试已加载的程序。这包括:…...

Ubuntu18.04 下配置Clion

配置Clion 安装gcc、g、make Ubuntu中用到的编译工具是gcc©,g(C),make(连接)。因此只需安装对应的工具包即可。Ubuntu下使用命令安装这些包: (1)安装gcc sudo apt install gcc&am…...

数据库管理-第九十四期 19c OCM之路-第四堂(02)(20230725)

第九十四期 19c OCM之路-第四堂(02)(20230725) 第四堂继续! 考点3:SQL statement tuning SQL语句调优 收集Schema统计信息 exec dbms_stats.gather_schems_stats(HR);开启制定表索引监控 create index…...

以智慧监测模式守护燃气安全 ,汉威科技“传感芯”凸显智慧力

城市燃气工程作为城市基建的重要组成部分,与城市居民生活、工业生产紧密相关。提升城市燃气服务质量和安全水平,也一直是政府和民众关注的大事。然而,近年来居民住宅、餐饮等工商业场所燃气事故频发,时刻敲响的警钟也折射出我国在…...

【阅读笔记】一种暗通道优先的快速自动白平衡算法

解决问题: 自动白平衡算法中存在白色区域检测错误导致白平衡失效的问题,作者提出了一种基于暗通道优先的白平衡算法。 算法思想: 图像中白色区域或者高饱和度区域的光线透射率较低,根据以上特性利用暗通道法计算图像中白色区域。 算法概述: 作者使用何凯明提出的基于暗…...

OpenStack之云主机管理

一&#xff09;必备知识 1.云主机与快照管理 a-云主机管理 云主机管理是OpenStack云计算平台的核心功能&#xff0c;通常&#xff0c;云主机的管理包括创建、删除、查询等。可使用以下命令对OpenStack的云主机进行管理&#xff1a; openstack server <操作><云主机…...

Linux系列---【Ubuntu 20.04安装KVM】

Ubuntu 20.04安装KVM 一、安装kvm 1.安装kvm sudo apt install qemu-kvm libvirt-daemon-system libvirt-clients bridge-utils 2. 将当前用户添加至libvirt 、 kvm组 sudo adduser $USER libvirt sudo adduser $USER kvm 3.验证安装 virsh list --all 4.启动libvert sudo syst…...

零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?

一、核心优势&#xff1a;专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发&#xff0c;是一款收费低廉但功能全面的Windows NAS工具&#xff0c;主打“无学习成本部署” 。与其他NAS软件相比&#xff0c;其优势在于&#xff1a; 无需硬件改造&#xff1a;将任意W…...

linux之kylin系统nginx的安装

一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源&#xff08;HTML/CSS/图片等&#xff09;&#xff0c;响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址&#xff0c;提高安全性 3.负载均衡服务器 支持多种策略分发流量…...

Cesium1.95中高性能加载1500个点

一、基本方式&#xff1a; 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...

Linux相关概念和易错知识点(42)(TCP的连接管理、可靠性、面临复杂网络的处理)

目录 1.TCP的连接管理机制&#xff08;1&#xff09;三次握手①握手过程②对握手过程的理解 &#xff08;2&#xff09;四次挥手&#xff08;3&#xff09;握手和挥手的触发&#xff08;4&#xff09;状态切换①挥手过程中状态的切换②握手过程中状态的切换 2.TCP的可靠性&…...

多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验

一、多模态商品数据接口的技术架构 &#xff08;一&#xff09;多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如&#xff0c;当用户上传一张“蓝色连衣裙”的图片时&#xff0c;接口可自动提取图像中的颜色&#xff08;RGB值&…...

第25节 Node.js 断言测试

Node.js的assert模块主要用于编写程序的单元测试时使用&#xff0c;通过断言可以提早发现和排查出错误。 稳定性: 5 - 锁定 这个模块可用于应用的单元测试&#xff0c;通过 require(assert) 可以使用这个模块。 assert.fail(actual, expected, message, operator) 使用参数…...

【2025年】解决Burpsuite抓不到https包的问题

环境&#xff1a;windows11 burpsuite:2025.5 在抓取https网站时&#xff0c;burpsuite抓取不到https数据包&#xff0c;只显示&#xff1a; 解决该问题只需如下三个步骤&#xff1a; 1、浏览器中访问 http://burp 2、下载 CA certificate 证书 3、在设置--隐私与安全--…...

2025盘古石杯决赛【手机取证】

前言 第三届盘古石杯国际电子数据取证大赛决赛 最后一题没有解出来&#xff0c;实在找不到&#xff0c;希望有大佬教一下我。 还有就会议时间&#xff0c;我感觉不是图片时间&#xff0c;因为在电脑看到是其他时间用老会议系统开的会。 手机取证 1、分析鸿蒙手机检材&#x…...

06 Deep learning神经网络编程基础 激活函数 --吴恩达

深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

【HTTP三个基础问题】

面试官您好&#xff01;HTTP是超文本传输协议&#xff0c;是互联网上客户端和服务器之间传输超文本数据&#xff08;比如文字、图片、音频、视频等&#xff09;的核心协议&#xff0c;当前互联网应用最广泛的版本是HTTP1.1&#xff0c;它基于经典的C/S模型&#xff0c;也就是客…...