当前位置: 首页 > news >正文

动手学深度学习(三)线性神经网络—softmax回归

分类任务是对离散变量预测,通过比较分类的概率来判断预测的结果。

softmax回归和线性回归一样也是将输入特征与权重做线性叠加,但是softmax回归的输出值个数等于标签中的类别数,这样就可以用于预测分类问题。

分类问题和线性回归的区别:分类任务通常有多个输出,作为不同类别的置信度。

一、softmax回归

1.1 网络架构

为了解决线性模型的分类问题,我们需要和输出一样多的仿射函数,每个输出对应它自己的仿射函数。

与线性回归一样,softmax回归也是一个单层神经网络。

在softmax回归中,输出层的输出值大小就代表其所属类别的置信度大小,置信度最大的那个类别我们将其作为预测。

1.2 softmax运算

首先,分类任务的目标是通过比较每个类别的置信度大小来判断预测的结果。但是,我们不能选择未规范化的最大输出值的 o_i 的类别作为我们的预测,原因有两点:

1. 输出值 o_i的总和不一定为1

2. 输出值 o_i有可能为负数。

这违反了概率论基本公理,很难判断所预测的类别是否真符合真实值。

softmax函数通过如下公式,解决了以上问题

softmax函数确保了输出值的非负,和为1,这一种规范手段。

1.3 交叉熵损失函数

交叉熵损失常用来衡量两个概率之间的差别

根据公式推断, 交叉熵损失函数的偏导数是我们softmax函数分配的概率与实际发生的情况之间的差距,换句话来说,其梯度是真实概率 y 和预测概率 \hat{y} 之间的差距。

二、图像分类数据集

MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。我们将使用类似但更复杂的Fashion-MNIST数据集。

2.1 导包

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()  # 用SVG显示图片

2.2 创建数据集

通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。

# 通过ToTensor实例将图像数据从PIL类型转化成32位的浮点数格式
# 并除以255使得所有像素的数值均在0到1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

查看Fashion-MNIST训练集和测试集大小,分别包含60000,10000张图片。

print(len(mnist_train), len(mnist_test))

查看图片分辨率,图片分辨率大小为[1, 28, 28]。

print(mnist_train[0][0].shape)

补充

torchvision.datasets 是Torchvision提供的标准数据集。

torchvision.transforms是包含一系列常用图像变换方法的包,可用于图像预处理、数据增强等工作。

torchvision.transforms.ToTensor()把一个取值范围是[0,255]PIL.Image或者shape(H,W,C)numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]torch.FloadTensor(浮点型的tensor)。

 

2.3 可视化数据集函数

# 可视化数据集函数
def get_fashion_mnist_labels(labels):"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):"""绘制图像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) # 创建绘制num_rows*num_cols个子图的位置区域axes = axes.flatten() # 降维成一维数组for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy()) # 负责对图像进行处理,并存入内存,并不显示else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False) #不显示y轴ax.axes.get_yaxis().set_visible(False) #不显示x轴if titles:ax.set_title(titles[i])return axes

可视化展示训练集中前18个图片。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))
d2l.plt.show() # 将plt.imshow()处理后的数据显示出来

plt.subplots(num_rows, num_cols, figsize):创建绘制num_rows*num_cols个子图的位置区域,其中子图大小为figsize。

enumerate():获取可迭代对象的每个元素的索引值及该元素值。

zip():用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。

imshow():负责对图像进行处理,并存入内存,并不显示。

plt.show():将plt.imshow()处理后的数据显示出来。

2.4 读取小批量

使用4个进程,以批量大小为256,来读取数据集。

# 读取小批量
batch_size = 256def get_dataloader_workers(): """使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

2.5 整合所有组件

这个函数包含了以上所有工作。

def load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize)) #修改图片大小trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

相关文章:

动手学深度学习(三)线性神经网络—softmax回归

分类任务是对离散变量预测,通过比较分类的概率来判断预测的结果。 softmax回归和线性回归一样也是将输入特征与权重做线性叠加,但是softmax回归的输出值个数等于标签中的类别数,这样就可以用于预测分类问题。 分类问题和线性回归的区别&#…...

ios swift alert 自定义弹框 点击半透明部分弹框消失

文章目录 1.BaseAlertVC2.BindFrameNumAlertVC 1.BaseAlertVC import UIKitclass BaseAlertVC: GLBaseViewController {let centerView UIView()override func viewDidLoad() {super.viewDidLoad()view.backgroundColor UIColor(displayP3Red: 0, green: 0, blue: 0, alpha:…...

HCIP STP(生成树)

目录 一、STP概述 二、生成树协议原理 三、802.1D生成树 四、STP的配置BPDU 1、配置BPDU的报文格式 2、配置BPDU的工作过程 3、TCN BPDU 4、TCN BPDU的工作过程 五、STP角色选举 1、根网桥选举 2、根端口选举 3、指定端口选举 4、非指定端口选举 六、STP的接口状…...

【Unity开发必备】100多个 Unity 学习网址 资源 收藏整理大全【持续更新】

Unity 相关网站整理大全 众所周知,工欲善其事必先利其器,有一个好的工具可以让我们事半功倍,有一个好用的网站更是如此! 但是好用的网站真的太多了,收藏夹都满满的(但是几乎没打开用过😁)。 所以本文是对…...

Alpine Ridge控制器使其具备多种使用模式 - 英特尔发布雷电3接口:竟和USB Type-C统一了

同时又因为这建立在Type-C的基础上,雷电3也将利用现有的标准Type-C线缆引入有源支持。当使用Type-C的线缆时,雷电的速度就降到了20Gbps全双工——这与普通的Type-C的带宽相同——这是为了成本牺牲了一些带宽。可以比较一下,Type-C线的成本只有…...

容器——2.Collection 子接口之 List

文章目录 2.1. Arraylist 和 Vector 的区别?2.2. Arraylist 与 LinkedList 区别?2.2.1. 补充内容:双向链表和双向循环链表2.2.2. 补充内容:RandomAccess 接口 2.3 ArrayList 的扩容机制 2.1. Arraylist 和 Vector 的区别? ArrayList 是 List 的主要实现类,底层使…...

【工作记录】docker安装gitlab、重置密码@20230809

前言 本文记录下基于docker安装gitlab并重置管理员密码的过程。 作为记录的同时也希望能帮助到需要的朋友们。 搭建过程 1. 准备好docker环境并启动docker [rootslave-node1 docker-gitlab]# docker version Client:Version: 18.06.1-ceAPI version: 1.38…...

数据挖掘的基本概念和大数据的特点

数据挖掘是指从大量数据中提取有价值的信息或模式的过程。它通常使用计算机技术来分析数据,并利用统计学、机器学习、人工智能等方法来发现数据中的隐藏规律、趋势和关联性。 数据挖掘的基本概念包括以下几个方面: 数据预处理:对原始数据进行…...

LabVIEW开发分段反射器测试台

LabVIEW开发分段反射器测试台 随着对太空的观察需求越来越远,而不是当前技术(如哈勃望远镜)所能达到的,有必要增加太空望远镜主镜的尺寸。但是,增加主镜像的大小时存在几个问题。随着反射镜尺寸的增加,制造…...

二级python和二级c哪个简单,二级c语言和二级python

大家好,小编为大家解答二级c语言和二级office一起报可以吗的问题。很多人还不知道计算机二级c语言和python哪个好考,现在让我们一起来看看吧! 介绍Python有很多库和使用Qt编写的接口,这自然创建c调用Python的需求。一路摸索,充满艰辛的添加头…...

E: Package ‘curl‘ has no installation candidate/ E:软件包没有可用的安装源

解决方案: 访问etc/apt/source.list 修改或者添加安装源 不用版本的Linux 有不同的配置比如我的是Debian 12 其他版本的去搜索引擎搜索即可 vim /etc/apt/source.list 改成修改或添加 // 以下是官方示例 deb http://deb.debian.org/debian bookworm main non-…...

代理模式及常见的3种代理类型对比

代理模式及常见的3种代理类型对比 代理模式代理模式分类静态代理JDK动态代理CGLIB动态代理Fastclass机制 三种代理方式之间对比常见问题 代理模式 代理模式是一种设计模式,提供了对目标对象额外的访问方式,即通过代理对象访问目标对象,这样可…...

8.6 校招 内推 面经

绿泡泡: neituijunsir 交流裙,内推/实习/校招汇总表格 1、面经 | 车载测试-23 面经 | 车载测试-23 2、校招 | 荣耀2024届全球校园招聘启动(内推) 校招 | 荣耀2024届全球校园招聘启动(内推) 3、校招 |…...

【大数据之Flume】七、Flume进阶之自定义Sink

(1)概述:   Sink 不断地轮询 Channel 中的事件且批量地移除它们,并将这些事件批量写入到存储或索引系统、或者被发送到另一个 Flume Agent。 Sink 是完全事务性的。在从 Channel 批量删除数据之前,每个 Sink 用 Chan…...

vue对于时间的处理

2023-08-05 11:25:45 假如这个就是我们要传的时间字符串 比如今天是2023-08-05(同一天):现在把这个时间字符串传入到 formatDate()这个方法,就会给你返回 11:25 比如今天是2023-08-06(前一天&a…...

Apache DolphinScheduler 3.1.8 版本发布,修复 SeaTunnel 相关 Bug

近日,Apache DolphinScheduler 发布了 3.1.8 版本。此版本主要基于 3.1.7 版本进行了 bug 修复,共计修复 16 个 bug, 1 个 doc, 2 个 chore。 其中修复了以下几个较为重要的问题: 修复在构建 SeaTunnel 任务节点的参数时错误的判断条件修复 …...

科技云报道:一波未平一波又起?AI大模型再出邪恶攻击工具

AI大模型的快速向前奔跑,让我们见识到了AI的无限可能,但也展示了AI在虚假信息、深度伪造和网络攻击方面的潜在威胁。 据安全分析平台Netenrich报道,近日,一款名为FraudGPT的AI工具近期在暗网上流通,并被犯罪分子用于编…...

深度对话|如何设计合适的网络经济激励措施

近日,我们与Mysten Labs的首席经济学家Alonso de Gortari进行了对话,讨论了如何在网络运营商和参与者之间找到激励措施的平衡,以及Sui的经济如何不断发展。 是什么让您选择将自己的经济学背景应用于区块链和Web3领域? 起初&…...

opencv带GStreamer之Windows编译

目录 1、下载GStreamer和安装2. GSTReamer CMake配置3. 验证是否配置成功 1、下载GStreamer和安装 下载地址如下: gstreamer-1.0-msvc-x86_64-1.18.2.msi gstreamer-1.0-devel-msvc-x86_64-1.18.2.msi 安装目录无要求,主要是安装完设置环境变量 xxx\1…...

Java并发编程之锁的升级

Java 中的锁机制是多线程编程中的一部分。锁一共有4种状态,级别从低到高依次是:无锁状态、偏向锁状态、轻量级锁状态和重量级锁状态,这几个状态会随着竞争情况逐渐升级。 锁可以升级但不能降级,意味着偏向锁升级成轻量级锁后不能…...

使用分级同态加密防御梯度泄漏

抽象 联邦学习 (FL) 支持跨分布式客户端进行协作模型训练,而无需共享原始数据,这使其成为在互联和自动驾驶汽车 (CAV) 等领域保护隐私的机器学习的一种很有前途的方法。然而,最近的研究表明&…...

解锁数据库简洁之道:FastAPI与SQLModel实战指南

在构建现代Web应用程序时,与数据库的交互无疑是核心环节。虽然传统的数据库操作方式(如直接编写SQL语句与psycopg2交互)赋予了我们精细的控制权,但在面对日益复杂的业务逻辑和快速迭代的需求时,这种方式的开发效率和可…...

Golang dig框架与GraphQL的完美结合

将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用,可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器,能够帮助开发者更好地管理复杂的依赖关系,而 GraphQL 则是一种用于 API 的查询语言,能够提…...

macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用

文章目录 问题现象问题原因解决办法 问题现象 macOS启动台(Launchpad)多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显,都是Google家的办公全家桶。这些应用并不是通过独立安装的…...

第25节 Node.js 断言测试

Node.js的assert模块主要用于编写程序的单元测试时使用,通过断言可以提早发现和排查出错误。 稳定性: 5 - 锁定 这个模块可用于应用的单元测试,通过 require(assert) 可以使用这个模块。 assert.fail(actual, expected, message, operator) 使用参数…...

Spring AI 入门:Java 开发者的生成式 AI 实践之路

一、Spring AI 简介 在人工智能技术快速迭代的今天,Spring AI 作为 Spring 生态系统的新生力量,正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务(如 OpenAI、Anthropic)的无缝对接&…...

docker 部署发现spring.profiles.active 问题

报错: org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

Unsafe Fileupload篇补充-木马的详细教程与木马分享(中国蚁剑方式)

在之前的皮卡丘靶场第九期Unsafe Fileupload篇中我们学习了木马的原理并且学了一个简单的木马文件 本期内容是为了更好的为大家解释木马(服务器方面的)的原理,连接,以及各种木马及连接工具的分享 文件木马:https://w…...

免费PDF转图片工具

免费PDF转图片工具 一款简单易用的PDF转图片工具,可以将PDF文件快速转换为高质量PNG图片。无需安装复杂的软件,也不需要在线上传文件,保护您的隐私。 工具截图 主要特点 🚀 快速转换:本地转换,无需等待上…...

push [特殊字符] present

push 🆚 present 前言present和dismiss特点代码演示 push和pop特点代码演示 前言 在 iOS 开发中,push 和 present 是两种不同的视图控制器切换方式,它们有着显著的区别。 present和dismiss 特点 在当前控制器上方新建视图层级需要手动调用…...