线性神经网络——softmax 回归随笔【深度学习】【PyTorch】【d2l】
文章目录
- 3.2、softmax 回归
- 3.2.1、softmax运算
- 3.2.2、交叉熵损失函数
- 3.2.3、PyTorch 从零实现 softmax 回归
- 3.2.4、简单实现 softmax 回归
3.2、softmax 回归
3.2.1、softmax运算

softmax 函数是一种常用的激活函数,用于将实数向量转换为概率分布向量。它在多类别分类问题中起到重要的作用,并与交叉熵损失函数结合使用。
y ^ = s o f t m a x ( o ) 其中 y ^ i = e x p ( o j ) ∑ k e x p ( o k ) \hat{y} = softmax(o) \ \ \ \ \ 其中\ \ \ \ \hat{y}_i = \frac{exp(o_j)}{\sum_{k}exp(o_k)} y^=softmax(o) 其中 y^i=∑kexp(ok)exp(oj)
其中,O为小批量的未规范化的预测, Y ^ \hat{Y} Y^w为输出概率,是一个正确的概率分布【 ∑ y i = 1 \sum{y_i} =1 ∑yi=1 】
3.2.2、交叉熵损失函数
通过测量给定模型编码的比特位,来衡量两概率分布之间的差异,是分类问题中常用的 loss 函数。
H ( P , Q ) = − Σ P ( x ) ∗ l o g ( Q ( x ) ) H(P, Q) = -Σ P(x) * log(Q(x)) H(P,Q)=−ΣP(x)∗log(Q(x))
真实概率分布是从哪里得知的?
真实标签的概率分布是由数据集中的标签信息提供的,通常使用单热编码表示。
softmax() 如何与交叉熵函数搭配的?
softmax 函数与交叉熵损失函数常用于多分类任务中。softmax 函数用于将模型输出转化为概率分布形式,交叉熵损失函数用于衡量模型输出概率分布与真实标签的差异,并通过优化算法来最小化损失函数,从而训练出更准确的分类模型。
3.2.3、PyTorch 从零实现 softmax 回归
(非完整代码)
#在 Notebook 中内嵌绘图
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l#,将图形显示格式设置为 SVG 格式,以在 Notebook 中以矢量图形的形式显示图像。这有助于提高图像的清晰度和可缩放性。
d2l .use_svg_display()
在线下载数据集 Fashion-MNIST
#将图像数据转换为张量形式
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data",train=False,transform =trans,download=True)len(mnist_train),len(mnist_test)
绘图(略)
读取小批量数据集
batch_size = 256def get_dataloader_workers():"""使用4进程读取"""return 4train_iter = data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers())
timer = d2l.Timer()
for X,y in train_iter:continue
print(f'{timer.stop():.2f}sec')
定义softmax操作
def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp / partition # 这里应用了广播机制
定义损失函数
def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])cross_entropy(y_hat, y)
分类精度
def accuracy(y_hat, y): #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())
评估
def evaluate_accuracy(net, data_iter): #@save"""计算在指定数据集上模型的精度"""if isinstance(net, torch.nn.Module):net.eval() # 将模型设置为评估模式metric = Accumulator(2) # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]
class Accumulator: #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]
3.2.4、简单实现 softmax 回归
导入前面已下载数据集 Fashion-MNIST
import torch
from torch import nn
from d2l import torch as d2lbatch_size =256
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)
初始化模型
#nn.Flatten() 层的作用是将输入数据展平,将二维输入(如图像)转换为一维向量。因为线性层(nn.Linear)通常期望接收一维输入。
#nn.Linear(784,10) 将输入特征从 784 维降低到 10 维,用于图像分类问题中的 10 个类别的预测 784维向量->10维向量
net = nn.Sequential(nn.Flatten(),nn.Linear(784,10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight,std=0.01)net.apply(init_weights);
#计算交叉熵损失函数,用于衡量模型预测与真实标签之间的差异。参数 reduction 控制了损失的计算方式。
#reduction='none' 表示不进行损失的降维或聚合操作,即返回每个样本的独立损失值。
loss = nn.CrossEntropyLoss(reduction='none')
优化算法
trainer = torch.optim.SGD(net.parameters(),lr=0.1)
训练
num_epochs = 10
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)
相关文章:
线性神经网络——softmax 回归随笔【深度学习】【PyTorch】【d2l】
文章目录 3.2、softmax 回归3.2.1、softmax运算3.2.2、交叉熵损失函数3.2.3、PyTorch 从零实现 softmax 回归3.2.4、简单实现 softmax 回归 3.2、softmax 回归 3.2.1、softmax运算 softmax 函数是一种常用的激活函数,用于将实数向量转换为概率分布向量。它在多类别…...
【Nodejs】Node.js开发环境安装
1.版本介绍 在命令窗口中输入 node -v 可以查看版本 0.x 完全不技术 ES64.x 部分支持 ES6 特性5.x 部分支持ES6特性(比4.x多些),属于过渡产品,现在来说应该没有什么理由去用这个了6.x 支持98%的 ES6 特性8.x 支持 ES6 特性 2.No…...
梅尔频谱(Mel spectrum)简介及Python实现
梅尔频谱(Mel spectrum)简介及Python实现 1. 梅尔频谱(Mel spectrum)简介2. Python可视化测试3.频谱可视化3.1 Mel 频谱可视化3.2 STFT spectrum参考文献资料1. 梅尔频谱(Mel spectrum)简介 在信号处理上,声信号(噪声信号)是一种重要的传感监测手段。对于语音分类任务…...
【数据结构】实验六:队列
实验六 队列 一、实验目的与要求 1)熟悉C/C语言(或其他编程语言)的集成开发环境; 2)通过本实验加深对队列的理解,熟悉基本操作; 3) 结合具体的问题分析算法时间复杂度。 二、…...
【Linux线程】第一章||理解线程概念+创建一个线程(附代码加讲解)
线程概念 🌵什么是线程🌲线程和进程的关系🎄线程有以下特点:🌳 线程的优点🌴 线程的缺点🌱线程异常🌿线程用途 ☘️手动创建一个进程🍀运行 🌵什么是线程 在L…...
Android进阶之微信扫码登录
遇到新需求要搭建微信扫码登录功能,这篇文章是随着我的编码过程一并写的,希望能够帮助有需求的人和以后再次用到此功能的自己。 首先想到的就是百度各种文章,当然去开发者平台申请AppID和密钥是必不可少的,等注册好发现需要创建应用以及审核(要官网,流程图及其他信息),想着先写…...
macOS Monterey 12.6.8 (21G725) Boot ISO 原版可引导镜像
macOS Monterey 12.6.8 (21G725) Boot ISO 原版可引导镜像 本站下载的 macOS 软件包,既可以拖拽到 Applications(应用程序)下直接安装,也可以制作启动 U 盘安装,或者在虚拟机中启动安装。另外也支持在 Windows 和 Lin…...
Unity自定义后处理——用偏导数求图片颜色边缘
大家好,我是阿赵。 继续介绍屏幕后处理效果的做法。这次介绍一下用偏导数求图形边缘的技术。 一、原理介绍 先来看例子吧。 这个例子看起来好像是要给模型描边。之前其实也介绍过很多描边的方法,比如沿着法线方向放大模型,或者用Ndo…...
本地Git仓库和GitHub仓库SSH传输
SSH创建命令解释 ssh-keygen 用于创建密钥的程序 -m PEM 将密钥的格式设为 PEM -t rsa 要创建的密钥类型,本例中为 RSA 格式 -b 4096 密钥的位数,本例中为 4096 -C “azureusermyserver” 追加到公钥文件末尾以便于识别的注释。 通常以电子邮件地址…...
【C++11】——右值引用、移动语义
目录 1. 基本概念 1.1 左值与左值引用 1.2 右值和右值引用 1.3 左值引用与右值引用 2. 右值引用实用场景和意义 2.1 左值引用的使用场景 2.2 左值引用的短板 2.3 右值引用和移动语义 2.3.1 移动构造 2.3.2 移动赋值 2.3.3 编译器做的优化 2.3.4 总结 2.4 右值引用…...
消息服务概述
消息服务的作用: 在多数应用尤其是分布式系统中,消息服务是不可或缺的重要部分,它使用起来比较简单,同时解决了不少难题,例如异步处理、应用解耦、流量削锋、分布式事务管理等,使用消息服务可以实现一个高…...
【Spring Boot】Web开发 — 数据验证
Web开发 — 数据验证 对于应用系统而言,任何客户端传入的数据都不是绝对安全有效的,这就要求我们在服务端接收到数据时也对数据的有效性进行验证,以确保传入的数据安全正确。接下来介绍Spring Boot是如何实现数据验证的。 1.Hibernate Vali…...
技术分享 | App常见bug解析
功能Bug 内容显示错误 前端页面展示的内容有误。 这种错误的产生有两种可能 1、前端代码写的文案错误 2、接口返回值错误 功能错误 功能错误是在测试过程中最常见的类型之一,也就是产品的功能没有实现。比如图中的公众号登录不成功的问题。 界面展示错乱 产…...
树莓派Pico|RP2040|使用SWD进行调试|构建 “Hello World“ debug版本
文章目录 使用SWD进行调试构建 "Hello World" debug版本安装 GDB使用 GDB 和 OpenOCD 来 debug Hello World TIP重要提示 使用SWD进行调试 基于rp2040的板上的SWD端口重置,加载和运行代码,如树莓派Pico可用于交互式调试已加载的程序。这包括:…...
Ubuntu18.04 下配置Clion
配置Clion 安装gcc、g、make Ubuntu中用到的编译工具是gcc©,g(C),make(连接)。因此只需安装对应的工具包即可。Ubuntu下使用命令安装这些包: (1)安装gcc sudo apt install gcc&am…...
数据库管理-第九十四期 19c OCM之路-第四堂(02)(20230725)
第九十四期 19c OCM之路-第四堂(02)(20230725) 第四堂继续! 考点3:SQL statement tuning SQL语句调优 收集Schema统计信息 exec dbms_stats.gather_schems_stats(HR);开启制定表索引监控 create index…...
以智慧监测模式守护燃气安全 ,汉威科技“传感芯”凸显智慧力
城市燃气工程作为城市基建的重要组成部分,与城市居民生活、工业生产紧密相关。提升城市燃气服务质量和安全水平,也一直是政府和民众关注的大事。然而,近年来居民住宅、餐饮等工商业场所燃气事故频发,时刻敲响的警钟也折射出我国在…...
【阅读笔记】一种暗通道优先的快速自动白平衡算法
解决问题: 自动白平衡算法中存在白色区域检测错误导致白平衡失效的问题,作者提出了一种基于暗通道优先的白平衡算法。 算法思想: 图像中白色区域或者高饱和度区域的光线透射率较低,根据以上特性利用暗通道法计算图像中白色区域。 算法概述: 作者使用何凯明提出的基于暗…...
OpenStack之云主机管理
一)必备知识 1.云主机与快照管理 a-云主机管理 云主机管理是OpenStack云计算平台的核心功能,通常,云主机的管理包括创建、删除、查询等。可使用以下命令对OpenStack的云主机进行管理: openstack server <操作><云主机…...
Linux系列---【Ubuntu 20.04安装KVM】
Ubuntu 20.04安装KVM 一、安装kvm 1.安装kvm sudo apt install qemu-kvm libvirt-daemon-system libvirt-clients bridge-utils 2. 将当前用户添加至libvirt 、 kvm组 sudo adduser $USER libvirt sudo adduser $USER kvm 3.验证安装 virsh list --all 4.启动libvert sudo syst…...
RestClient
什么是RestClient RestClient 是 Elasticsearch 官方提供的 Java 低级 REST 客户端,它允许HTTP与Elasticsearch 集群通信,而无需处理 JSON 序列化/反序列化等底层细节。它是 Elasticsearch Java API 客户端的基础。 RestClient 主要特点 轻量级ÿ…...
深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法
深入浅出:JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中,随机数的生成看似简单,却隐藏着许多玄机。无论是生成密码、加密密钥,还是创建安全令牌,随机数的质量直接关系到系统的安全性。Jav…...
质量体系的重要
质量体系是为确保产品、服务或过程质量满足规定要求,由相互关联的要素构成的有机整体。其核心内容可归纳为以下五个方面: 🏛️ 一、组织架构与职责 质量体系明确组织内各部门、岗位的职责与权限,形成层级清晰的管理网络…...
Mac软件卸载指南,简单易懂!
刚和Adobe分手,它却总在Library里给你写"回忆录"?卸载的Final Cut Pro像电子幽灵般阴魂不散?总是会有残留文件,别慌!这份Mac软件卸载指南,将用最硬核的方式教你"数字分手术"࿰…...
ardupilot 开发环境eclipse 中import 缺少C++
目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...
大数据学习(132)-HIve数据分析
🍋🍋大数据学习🍋🍋 🔥系列专栏: 👑哲学语录: 用力所能及,改变世界。 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言Ǵ…...
Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...
Java线上CPU飙高问题排查全指南
一、引言 在Java应用的线上运行环境中,CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时,通常会导致应用响应缓慢,甚至服务不可用,严重影响用户体验和业务运行。因此,掌握一套科学有效的CPU飙高问题排查方法&…...
JVM虚拟机:内存结构、垃圾回收、性能优化
1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...
浪潮交换机配置track检测实现高速公路收费网络主备切换NQA
浪潮交换机track配置 项目背景高速网络拓扑网络情况分析通信线路收费网络路由 收费汇聚交换机相应配置收费汇聚track配置 项目背景 在实施省内一条高速公路时遇到的需求,本次涉及的主要是收费汇聚交换机的配置,浪潮网络设备在高速项目很少,通…...
