当前位置: 首页 > news >正文

线性神经网络——softmax 回归随笔【深度学习】【PyTorch】【d2l】

文章目录

    • 3.2、softmax 回归
      • 3.2.1、softmax运算
      • 3.2.2、交叉熵损失函数
      • 3.2.3、PyTorch 从零实现 softmax 回归
      • 3.2.4、简单实现 softmax 回归

在这里插入图片描述

3.2、softmax 回归

3.2.1、softmax运算

在这里插入图片描述

softmax 函数是一种常用的激活函数,用于将实数向量转换为概率分布向量。它在多类别分类问题中起到重要的作用,并与交叉熵损失函数结合使用。

y ^ = s o f t m a x ( o ) 其中     y ^ i = e x p ( o j ) ∑ k e x p ( o k ) \hat{y} = softmax(o) \ \ \ \ \ 其中\ \ \ \ \hat{y}_i = \frac{exp(o_j)}{\sum_{k}exp(o_k)} y^=softmax(o)     其中    y^i=kexp(ok)exp(oj)

其中,O为小批量的未规范化的预测, Y ^ \hat{Y} Y^w为输出概率,是一个正确的概率分布【 ∑ y i = 1 \sum{y_i} =1 yi=1

3.2.2、交叉熵损失函数

通过测量给定模型编码的比特位,来衡量两概率分布之间的差异,是分类问题中常用的 loss 函数。

H ( P , Q ) = − Σ P ( x ) ∗ l o g ( Q ( x ) ) H(P, Q) = -Σ P(x) * log(Q(x)) H(P,Q)=ΣP(x)log(Q(x))

真实概率分布是从哪里得知的?

真实标签的概率分布是由数据集中的标签信息提供的,通常使用单热编码表示。

softmax() 如何与交叉熵函数搭配的?

softmax 函数与交叉熵损失函数常用于多分类任务中。softmax 函数用于将模型输出转化为概率分布形式,交叉熵损失函数用于衡量模型输出概率分布与真实标签的差异,并通过优化算法来最小化损失函数,从而训练出更准确的分类模型。

3.2.3、PyTorch 从零实现 softmax 回归

(非完整代码)

#在 Notebook 中内嵌绘图
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l#,将图形显示格式设置为 SVG 格式,以在 Notebook 中以矢量图形的形式显示图像。这有助于提高图像的清晰度和可缩放性。
d2l .use_svg_display()

在线下载数据集 Fashion-MNIST

#将图像数据转换为张量形式
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data",train=False,transform =trans,download=True)len(mnist_train),len(mnist_test)

绘图(略)

读取小批量数据集

batch_size = 256def get_dataloader_workers():"""使用4进程读取"""return 4train_iter = data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers())
timer = d2l.Timer()
for X,y in train_iter:continue
print(f'{timer.stop():.2f}sec')

定义softmax操作

def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp / partition  # 这里应用了广播机制

定义损失函数

def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])cross_entropy(y_hat, y)

分类精度

def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())

评估

def evaluate_accuracy(net, data_iter):  #@save"""计算在指定数据集上模型的精度"""if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]
class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]

3.2.4、简单实现 softmax 回归

导入前面已下载数据集 Fashion-MNIST

import torch 
from torch import nn
from d2l import torch as d2lbatch_size =256
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)

初始化模型

#nn.Flatten() 层的作用是将输入数据展平,将二维输入(如图像)转换为一维向量。因为线性层(nn.Linear)通常期望接收一维输入。
#nn.Linear(784,10) 将输入特征从 784 维降低到 10 维,用于图像分类问题中的 10 个类别的预测   784维向量->10维向量
net = nn.Sequential(nn.Flatten(),nn.Linear(784,10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight,std=0.01)net.apply(init_weights);
#计算交叉熵损失函数,用于衡量模型预测与真实标签之间的差异。参数 reduction 控制了损失的计算方式。
#reduction='none' 表示不进行损失的降维或聚合操作,即返回每个样本的独立损失值。
loss = nn.CrossEntropyLoss(reduction='none')

优化算法

trainer = torch.optim.SGD(net.parameters(),lr=0.1)

训练

num_epochs = 10
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)

相关文章:

线性神经网络——softmax 回归随笔【深度学习】【PyTorch】【d2l】

文章目录 3.2、softmax 回归3.2.1、softmax运算3.2.2、交叉熵损失函数3.2.3、PyTorch 从零实现 softmax 回归3.2.4、简单实现 softmax 回归 3.2、softmax 回归 3.2.1、softmax运算 softmax 函数是一种常用的激活函数,用于将实数向量转换为概率分布向量。它在多类别…...

【Nodejs】Node.js开发环境安装

1.版本介绍 在命令窗口中输入 node -v 可以查看版本 0.x 完全不技术 ES64.x 部分支持 ES6 特性5.x 部分支持ES6特性(比4.x多些),属于过渡产品,现在来说应该没有什么理由去用这个了6.x 支持98%的 ES6 特性8.x 支持 ES6 特性 2.No…...

梅尔频谱(Mel spectrum)简介及Python实现

梅尔频谱(Mel spectrum)简介及Python实现 1. 梅尔频谱(Mel spectrum)简介2. Python可视化测试3.频谱可视化3.1 Mel 频谱可视化3.2 STFT spectrum参考文献资料1. 梅尔频谱(Mel spectrum)简介 在信号处理上,声信号(噪声信号)是一种重要的传感监测手段。对于语音分类任务…...

【数据结构】实验六:队列

实验六 队列 一、实验目的与要求 1)熟悉C/C语言(或其他编程语言)的集成开发环境; 2)通过本实验加深对队列的理解,熟悉基本操作; 3) 结合具体的问题分析算法时间复杂度。 二、…...

【Linux线程】第一章||理解线程概念+创建一个线程(附代码加讲解)

线程概念 🌵什么是线程🌲线程和进程的关系🎄线程有以下特点:🌳 线程的优点🌴 线程的缺点🌱线程异常🌿线程用途 ☘️手动创建一个进程🍀运行 🌵什么是线程 在L…...

Android进阶之微信扫码登录

遇到新需求要搭建微信扫码登录功能,这篇文章是随着我的编码过程一并写的,希望能够帮助有需求的人和以后再次用到此功能的自己。 首先想到的就是百度各种文章,当然去开发者平台申请AppID和密钥是必不可少的,等注册好发现需要创建应用以及审核(要官网,流程图及其他信息),想着先写…...

macOS Monterey 12.6.8 (21G725) Boot ISO 原版可引导镜像

macOS Monterey 12.6.8 (21G725) Boot ISO 原版可引导镜像 本站下载的 macOS 软件包,既可以拖拽到 Applications(应用程序)下直接安装,也可以制作启动 U 盘安装,或者在虚拟机中启动安装。另外也支持在 Windows 和 Lin…...

Unity自定义后处理——用偏导数求图片颜色边缘

大家好,我是阿赵。   继续介绍屏幕后处理效果的做法。这次介绍一下用偏导数求图形边缘的技术。 一、原理介绍 先来看例子吧。   这个例子看起来好像是要给模型描边。之前其实也介绍过很多描边的方法,比如沿着法线方向放大模型,或者用Ndo…...

本地Git仓库和GitHub仓库SSH传输

SSH创建命令解释 ssh-keygen 用于创建密钥的程序 -m PEM 将密钥的格式设为 PEM -t rsa 要创建的密钥类型,本例中为 RSA 格式 -b 4096 密钥的位数,本例中为 4096 -C “azureusermyserver” 追加到公钥文件末尾以便于识别的注释。 通常以电子邮件地址…...

【C++11】——右值引用、移动语义

目录 1. 基本概念 1.1 左值与左值引用 1.2 右值和右值引用 1.3 左值引用与右值引用 2. 右值引用实用场景和意义 2.1 左值引用的使用场景 2.2 左值引用的短板 2.3 右值引用和移动语义 2.3.1 移动构造 2.3.2 移动赋值 2.3.3 编译器做的优化 2.3.4 总结 2.4 右值引用…...

消息服务概述

消息服务的作用: 在多数应用尤其是分布式系统中,消息服务是不可或缺的重要部分,它使用起来比较简单,同时解决了不少难题,例如异步处理、应用解耦、流量削锋、分布式事务管理等,使用消息服务可以实现一个高…...

【Spring Boot】Web开发 — 数据验证

Web开发 — 数据验证 对于应用系统而言,任何客户端传入的数据都不是绝对安全有效的,这就要求我们在服务端接收到数据时也对数据的有效性进行验证,以确保传入的数据安全正确。接下来介绍Spring Boot是如何实现数据验证的。 1.Hibernate Vali…...

技术分享 | App常见bug解析

功能Bug 内容显示错误 前端页面展示的内容有误。 这种错误的产生有两种可能 1、前端代码写的文案错误 2、接口返回值错误 功能错误 功能错误是在测试过程中最常见的类型之一,也就是产品的功能没有实现。比如图中的公众号登录不成功的问题。 界面展示错乱 产…...

树莓派Pico|RP2040|使用SWD进行调试|构建 “Hello World“ debug版本

文章目录 使用SWD进行调试构建 "Hello World" debug版本安装 GDB使用 GDB 和 OpenOCD 来 debug Hello World TIP重要提示 使用SWD进行调试 基于rp2040的板上的SWD端口重置,加载和运行代码,如树莓派Pico可用于交互式调试已加载的程序。这包括:…...

Ubuntu18.04 下配置Clion

配置Clion 安装gcc、g、make Ubuntu中用到的编译工具是gcc©,g(C),make(连接)。因此只需安装对应的工具包即可。Ubuntu下使用命令安装这些包: (1)安装gcc sudo apt install gcc&am…...

数据库管理-第九十四期 19c OCM之路-第四堂(02)(20230725)

第九十四期 19c OCM之路-第四堂(02)(20230725) 第四堂继续! 考点3:SQL statement tuning SQL语句调优 收集Schema统计信息 exec dbms_stats.gather_schems_stats(HR);开启制定表索引监控 create index…...

以智慧监测模式守护燃气安全 ,汉威科技“传感芯”凸显智慧力

城市燃气工程作为城市基建的重要组成部分,与城市居民生活、工业生产紧密相关。提升城市燃气服务质量和安全水平,也一直是政府和民众关注的大事。然而,近年来居民住宅、餐饮等工商业场所燃气事故频发,时刻敲响的警钟也折射出我国在…...

【阅读笔记】一种暗通道优先的快速自动白平衡算法

解决问题: 自动白平衡算法中存在白色区域检测错误导致白平衡失效的问题,作者提出了一种基于暗通道优先的白平衡算法。 算法思想: 图像中白色区域或者高饱和度区域的光线透射率较低,根据以上特性利用暗通道法计算图像中白色区域。 算法概述: 作者使用何凯明提出的基于暗…...

OpenStack之云主机管理

一&#xff09;必备知识 1.云主机与快照管理 a-云主机管理 云主机管理是OpenStack云计算平台的核心功能&#xff0c;通常&#xff0c;云主机的管理包括创建、删除、查询等。可使用以下命令对OpenStack的云主机进行管理&#xff1a; openstack server <操作><云主机…...

Linux系列---【Ubuntu 20.04安装KVM】

Ubuntu 20.04安装KVM 一、安装kvm 1.安装kvm sudo apt install qemu-kvm libvirt-daemon-system libvirt-clients bridge-utils 2. 将当前用户添加至libvirt 、 kvm组 sudo adduser $USER libvirt sudo adduser $USER kvm 3.验证安装 virsh list --all 4.启动libvert sudo syst…...

KubeSphere 容器平台高可用:环境搭建与可视化操作指南

Linux_k8s篇 欢迎来到Linux的世界&#xff0c;看笔记好好学多敲多打&#xff0c;每个人都是大神&#xff01; 题目&#xff1a;KubeSphere 容器平台高可用&#xff1a;环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

【ROS】Nav2源码之nav2_behavior_tree-行为树节点列表

1、行为树节点分类 在 Nav2(Navigation2)的行为树框架中,行为树节点插件按照功能分为 Action(动作节点)、Condition(条件节点)、Control(控制节点) 和 Decorator(装饰节点) 四类。 1.1 动作节点 Action 执行具体的机器人操作或任务,直接与硬件、传感器或外部系统…...

多模态大语言模型arxiv论文略读(108)

CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文标题&#xff1a;CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文作者&#xff1a;Sayna Ebrahimi, Sercan O. Arik, Tejas Nama, Tomas Pfister ➡️ 研究机构: Google Cloud AI Re…...

在WSL2的Ubuntu镜像中安装Docker

Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包&#xff1a; for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...

什么是Ansible Jinja2

理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具&#xff0c;可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板&#xff0c;允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板&#xff0c;并通…...

iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈

在日常iOS开发过程中&#xff0c;性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期&#xff0c;开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发&#xff0c;但背后往往隐藏着系统资源调度不当…...

Linux nano命令的基本使用

参考资料 GNU nanoを使いこなすnano基础 目录 一. 简介二. 文件打开2.1 普通方式打开文件2.2 只读方式打开文件 三. 文件查看3.1 打开文件时&#xff0c;显示行号3.2 翻页查看 四. 文件编辑4.1 Ctrl K 复制 和 Ctrl U 粘贴4.2 Alt/Esc U 撤回 五. 文件保存与退出5.1 Ctrl …...

BLEU评分:机器翻译质量评估的黄金标准

BLEU评分&#xff1a;机器翻译质量评估的黄金标准 1. 引言 在自然语言处理(NLP)领域&#xff0c;衡量一个机器翻译模型的性能至关重要。BLEU (Bilingual Evaluation Understudy) 作为一种自动化评估指标&#xff0c;自2002年由IBM的Kishore Papineni等人提出以来&#xff0c;…...

GraphQL 实战篇:Apollo Client 配置与缓存

GraphQL 实战篇&#xff1a;Apollo Client 配置与缓存 上一篇&#xff1a;GraphQL 入门篇&#xff1a;基础查询语法 依旧和上一篇的笔记一样&#xff0c;主实操&#xff0c;没啥过多的细节讲解&#xff0c;代码具体在&#xff1a; https://github.com/GoldenaArcher/graphql…...

海云安高敏捷信创白盒SCAP入选《中国网络安全细分领域产品名录》

近日&#xff0c;嘶吼安全产业研究院发布《中国网络安全细分领域产品名录》&#xff0c;海云安高敏捷信创白盒&#xff08;SCAP&#xff09;成功入选软件供应链安全领域产品名录。 在数字化转型加速的今天&#xff0c;网络安全已成为企业生存与发展的核心基石&#xff0c;为了解…...