动手学深度学习(三)线性神经网络—softmax回归
分类任务是对离散变量预测,通过比较分类的概率来判断预测的结果。
softmax回归和线性回归一样也是将输入特征与权重做线性叠加,但是softmax回归的输出值个数等于标签中的类别数,这样就可以用于预测分类问题。
分类问题和线性回归的区别:分类任务通常有多个输出,作为不同类别的置信度。
一、softmax回归
1.1 网络架构
为了解决线性模型的分类问题,我们需要和输出一样多的仿射函数,每个输出对应它自己的仿射函数。
与线性回归一样,softmax回归也是一个单层神经网络。
在softmax回归中,输出层的输出值大小就代表其所属类别的置信度大小,置信度最大的那个类别我们将其作为预测。
1.2 softmax运算
首先,分类任务的目标是通过比较每个类别的置信度大小来判断预测的结果。但是,我们不能选择未规范化的最大输出值的 的类别作为我们的预测,原因有两点:
1. 输出值
的总和不一定为1
2. 输出值
有可能为负数。
这违反了概率论基本公理,很难判断所预测的类别是否真符合真实值。
softmax函数通过如下公式,解决了以上问题:
softmax函数确保了输出值的非负,和为1,这一种规范手段。
1.3 交叉熵损失函数
交叉熵损失常用来衡量两个概率之间的差别。
根据公式推断, 交叉熵损失函数的偏导数是我们softmax函数分配的概率与实际发生的情况之间的差距,换句话来说,其梯度是真实概率 和预测概率
之间的差距。
二、图像分类数据集
MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。我们将使用类似但更复杂的Fashion-MNIST数据集。
2.1 导包
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display() # 用SVG显示图片
2.2 创建数据集
通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。
# 通过ToTensor实例将图像数据从PIL类型转化成32位的浮点数格式
# 并除以255使得所有像素的数值均在0到1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
查看Fashion-MNIST训练集和测试集大小,分别包含60000,10000张图片。
print(len(mnist_train), len(mnist_test))
查看图片分辨率,图片分辨率大小为[1, 28, 28]。
print(mnist_train[0][0].shape)
补充:
torchvision.datasets
是Torchvision提供的标准数据集。
torchvision.transforms
是包含一系列常用图像变换方法的包,可用于图像预处理、数据增强等工作。
torchvision.transforms.ToTensor()
把一个取值范围是[0,255]
的PIL.Image
或者shape
为(H,W,C)
的numpy.ndarray
,转换成形状为[C,H,W]
,取值范围是[0,1.0]
的torch.FloadTensor(浮点型的tensor)。
2.3 可视化数据集函数
# 可视化数据集函数
def get_fashion_mnist_labels(labels):"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):"""绘制图像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) # 创建绘制num_rows*num_cols个子图的位置区域axes = axes.flatten() # 降维成一维数组for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy()) # 负责对图像进行处理,并存入内存,并不显示else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False) #不显示y轴ax.axes.get_yaxis().set_visible(False) #不显示x轴if titles:ax.set_title(titles[i])return axes
可视化展示训练集中前18个图片。
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))
d2l.plt.show() # 将plt.imshow()处理后的数据显示出来
plt.subplots(num_rows, num_cols, figsize):创建绘制num_rows*num_cols个子图的位置区域,其中子图大小为figsize。
enumerate()
:获取可迭代对象的每个元素的索引值及该元素值。
zip()
:用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。
imshow()
:负责对图像进行处理,并存入内存,并不显示。
plt.show()
:将plt.imshow()处理后的数据显示出来。
2.4 读取小批量
使用4个进程,以批量大小为256,来读取数据集。
# 读取小批量
batch_size = 256def get_dataloader_workers(): """使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())
2.5 整合所有组件
这个函数包含了以上所有工作。
def load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize)) #修改图片大小trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))
相关文章:

动手学深度学习(三)线性神经网络—softmax回归
分类任务是对离散变量预测,通过比较分类的概率来判断预测的结果。 softmax回归和线性回归一样也是将输入特征与权重做线性叠加,但是softmax回归的输出值个数等于标签中的类别数,这样就可以用于预测分类问题。 分类问题和线性回归的区别&#…...

ios swift alert 自定义弹框 点击半透明部分弹框消失
文章目录 1.BaseAlertVC2.BindFrameNumAlertVC 1.BaseAlertVC import UIKitclass BaseAlertVC: GLBaseViewController {let centerView UIView()override func viewDidLoad() {super.viewDidLoad()view.backgroundColor UIColor(displayP3Red: 0, green: 0, blue: 0, alpha:…...

HCIP STP(生成树)
目录 一、STP概述 二、生成树协议原理 三、802.1D生成树 四、STP的配置BPDU 1、配置BPDU的报文格式 2、配置BPDU的工作过程 3、TCN BPDU 4、TCN BPDU的工作过程 五、STP角色选举 1、根网桥选举 2、根端口选举 3、指定端口选举 4、非指定端口选举 六、STP的接口状…...

【Unity开发必备】100多个 Unity 学习网址 资源 收藏整理大全【持续更新】
Unity 相关网站整理大全 众所周知,工欲善其事必先利其器,有一个好的工具可以让我们事半功倍,有一个好用的网站更是如此! 但是好用的网站真的太多了,收藏夹都满满的(但是几乎没打开用过😁)。 所以本文是对…...

Alpine Ridge控制器使其具备多种使用模式 - 英特尔发布雷电3接口:竟和USB Type-C统一了
同时又因为这建立在Type-C的基础上,雷电3也将利用现有的标准Type-C线缆引入有源支持。当使用Type-C的线缆时,雷电的速度就降到了20Gbps全双工——这与普通的Type-C的带宽相同——这是为了成本牺牲了一些带宽。可以比较一下,Type-C线的成本只有…...

容器——2.Collection 子接口之 List
文章目录 2.1. Arraylist 和 Vector 的区别?2.2. Arraylist 与 LinkedList 区别?2.2.1. 补充内容:双向链表和双向循环链表2.2.2. 补充内容:RandomAccess 接口 2.3 ArrayList 的扩容机制 2.1. Arraylist 和 Vector 的区别? ArrayList 是 List 的主要实现类,底层使…...

【工作记录】docker安装gitlab、重置密码@20230809
前言 本文记录下基于docker安装gitlab并重置管理员密码的过程。 作为记录的同时也希望能帮助到需要的朋友们。 搭建过程 1. 准备好docker环境并启动docker [rootslave-node1 docker-gitlab]# docker version Client:Version: 18.06.1-ceAPI version: 1.38…...
数据挖掘的基本概念和大数据的特点
数据挖掘是指从大量数据中提取有价值的信息或模式的过程。它通常使用计算机技术来分析数据,并利用统计学、机器学习、人工智能等方法来发现数据中的隐藏规律、趋势和关联性。 数据挖掘的基本概念包括以下几个方面: 数据预处理:对原始数据进行…...

LabVIEW开发分段反射器测试台
LabVIEW开发分段反射器测试台 随着对太空的观察需求越来越远,而不是当前技术(如哈勃望远镜)所能达到的,有必要增加太空望远镜主镜的尺寸。但是,增加主镜像的大小时存在几个问题。随着反射镜尺寸的增加,制造…...

二级python和二级c哪个简单,二级c语言和二级python
大家好,小编为大家解答二级c语言和二级office一起报可以吗的问题。很多人还不知道计算机二级c语言和python哪个好考,现在让我们一起来看看吧! 介绍Python有很多库和使用Qt编写的接口,这自然创建c调用Python的需求。一路摸索,充满艰辛的添加头…...
E: Package ‘curl‘ has no installation candidate/ E:软件包没有可用的安装源
解决方案: 访问etc/apt/source.list 修改或者添加安装源 不用版本的Linux 有不同的配置比如我的是Debian 12 其他版本的去搜索引擎搜索即可 vim /etc/apt/source.list 改成修改或添加 // 以下是官方示例 deb http://deb.debian.org/debian bookworm main non-…...

代理模式及常见的3种代理类型对比
代理模式及常见的3种代理类型对比 代理模式代理模式分类静态代理JDK动态代理CGLIB动态代理Fastclass机制 三种代理方式之间对比常见问题 代理模式 代理模式是一种设计模式,提供了对目标对象额外的访问方式,即通过代理对象访问目标对象,这样可…...
8.6 校招 内推 面经
绿泡泡: neituijunsir 交流裙,内推/实习/校招汇总表格 1、面经 | 车载测试-23 面经 | 车载测试-23 2、校招 | 荣耀2024届全球校园招聘启动(内推) 校招 | 荣耀2024届全球校园招聘启动(内推) 3、校招 |…...

【大数据之Flume】七、Flume进阶之自定义Sink
(1)概述: Sink 不断地轮询 Channel 中的事件且批量地移除它们,并将这些事件批量写入到存储或索引系统、或者被发送到另一个 Flume Agent。 Sink 是完全事务性的。在从 Channel 批量删除数据之前,每个 Sink 用 Chan…...
vue对于时间的处理
2023-08-05 11:25:45 假如这个就是我们要传的时间字符串 比如今天是2023-08-05(同一天):现在把这个时间字符串传入到 formatDate()这个方法,就会给你返回 11:25 比如今天是2023-08-06(前一天&a…...

Apache DolphinScheduler 3.1.8 版本发布,修复 SeaTunnel 相关 Bug
近日,Apache DolphinScheduler 发布了 3.1.8 版本。此版本主要基于 3.1.7 版本进行了 bug 修复,共计修复 16 个 bug, 1 个 doc, 2 个 chore。 其中修复了以下几个较为重要的问题: 修复在构建 SeaTunnel 任务节点的参数时错误的判断条件修复 …...

科技云报道:一波未平一波又起?AI大模型再出邪恶攻击工具
AI大模型的快速向前奔跑,让我们见识到了AI的无限可能,但也展示了AI在虚假信息、深度伪造和网络攻击方面的潜在威胁。 据安全分析平台Netenrich报道,近日,一款名为FraudGPT的AI工具近期在暗网上流通,并被犯罪分子用于编…...

深度对话|如何设计合适的网络经济激励措施
近日,我们与Mysten Labs的首席经济学家Alonso de Gortari进行了对话,讨论了如何在网络运营商和参与者之间找到激励措施的平衡,以及Sui的经济如何不断发展。 是什么让您选择将自己的经济学背景应用于区块链和Web3领域? 起初&…...

opencv带GStreamer之Windows编译
目录 1、下载GStreamer和安装2. GSTReamer CMake配置3. 验证是否配置成功 1、下载GStreamer和安装 下载地址如下: gstreamer-1.0-msvc-x86_64-1.18.2.msi gstreamer-1.0-devel-msvc-x86_64-1.18.2.msi 安装目录无要求,主要是安装完设置环境变量 xxx\1…...

Java并发编程之锁的升级
Java 中的锁机制是多线程编程中的一部分。锁一共有4种状态,级别从低到高依次是:无锁状态、偏向锁状态、轻量级锁状态和重量级锁状态,这几个状态会随着竞争情况逐渐升级。 锁可以升级但不能降级,意味着偏向锁升级成轻量级锁后不能…...
uniapp 对接腾讯云IM群组成员管理(增删改查)
UniApp 实战:腾讯云IM群组成员管理(增删改查) 一、前言 在社交类App开发中,群组成员管理是核心功能之一。本文将基于UniApp框架,结合腾讯云IM SDK,详细讲解如何实现群组成员的增删改查全流程。 权限校验…...

AI-调查研究-01-正念冥想有用吗?对健康的影响及科学指南
点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...
云原生核心技术 (7/12): K8s 核心概念白话解读(上):Pod 和 Deployment 究竟是什么?
大家好,欢迎来到《云原生核心技术》系列的第七篇! 在上一篇,我们成功地使用 Minikube 或 kind 在自己的电脑上搭建起了一个迷你但功能完备的 Kubernetes 集群。现在,我们就像一个拥有了一块崭新数字土地的农场主,是时…...

【kafka】Golang实现分布式Masscan任务调度系统
要求: 输出两个程序,一个命令行程序(命令行参数用flag)和一个服务端程序。 命令行程序支持通过命令行参数配置下发IP或IP段、端口、扫描带宽,然后将消息推送到kafka里面。 服务端程序: 从kafka消费者接收…...
k8s从入门到放弃之Ingress七层负载
k8s从入门到放弃之Ingress七层负载 在Kubernetes(简称K8s)中,Ingress是一个API对象,它允许你定义如何从集群外部访问集群内部的服务。Ingress可以提供负载均衡、SSL终结和基于名称的虚拟主机等功能。通过Ingress,你可…...

(二)TensorRT-LLM | 模型导出(v0.20.0rc3)
0. 概述 上一节 对安装和使用有个基本介绍。根据这个 issue 的描述,后续 TensorRT-LLM 团队可能更专注于更新和维护 pytorch backend。但 tensorrt backend 作为先前一直开发的工作,其中包含了大量可以学习的地方。本文主要看看它导出模型的部分&#x…...
【算法训练营Day07】字符串part1
文章目录 反转字符串反转字符串II替换数字 反转字符串 题目链接:344. 反转字符串 双指针法,两个指针的元素直接调转即可 class Solution {public void reverseString(char[] s) {int head 0;int end s.length - 1;while(head < end) {char temp …...

2025盘古石杯决赛【手机取证】
前言 第三届盘古石杯国际电子数据取证大赛决赛 最后一题没有解出来,实在找不到,希望有大佬教一下我。 还有就会议时间,我感觉不是图片时间,因为在电脑看到是其他时间用老会议系统开的会。 手机取证 1、分析鸿蒙手机检材&#x…...
LLM基础1_语言模型如何处理文本
基于GitHub项目:https://github.com/datawhalechina/llms-from-scratch-cn 工具介绍 tiktoken:OpenAI开发的专业"分词器" torch:Facebook开发的强力计算引擎,相当于超级计算器 理解词嵌入:给词语画"…...
使用Matplotlib创建炫酷的3D散点图:数据可视化的新维度
文章目录 基础实现代码代码解析进阶技巧1. 自定义点的大小和颜色2. 添加图例和样式美化3. 真实数据应用示例实用技巧与注意事项完整示例(带样式)应用场景在数据科学和可视化领域,三维图形能为我们提供更丰富的数据洞察。本文将手把手教你如何使用Python的Matplotlib库创建引…...