当前位置: 首页 > news >正文

深度学习(五)softmax 回归之:分类算法介绍,如何加载 Fashion-MINIST 数据集

Softmax 回归

基本原理

回归和分类,是两种深度学习常用方法。回归是对连续的预测(比如我预测根据过去开奖列表下次双色球号),分类是预测离散的类别(手写语音识别,图片识别)。

1699720169075

现在我们已经对回归的处理有一定的理解了,如何过渡到分类呢?

假设我们有 n 类,首先我们要编码这些类让他们变成数据。所有类变成一个列向量。

y = [ y 1 , y 2 , . . . y n ] T y=[y_1,y_2,...y_n]^T y=[y1,y2,...yn]T

有一个数据属于第 i 类,那么他的列向量就是:

y = [ 0 , 0 , . . . , 1 , . . . , 0 , 0 ] T y=[0,0,...,1,...,0,0]^T y=[0,0,...,1,...,0,0]T

也就是只有他所在的那个类的元素=1.

可以用均方损失训练,通过概率判断最终选用哪一个。

Softmax 回归就是一种分类方式(回归问题在多分类上的推广)。首先确定输入特征数和输出类别数。比如上图中我们有4个特征和3个可能的类别,那么计算各自概率的公式包括3个线性回归:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

可以看出 Softmax 是全连接的单层神经网络。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们让所有输出结果归一化后,从中选择出最大可能的,置信度最高的分类结果。

image-20231112100423488

采用 e 的指数可以让值全变为非负。

用真实的概率向量-我们预测得到的概率向量就是损失。真实值就是只有一个1的列向量。

交叉熵损失:

image-20231112101259670

可见**分类问题,我们不关心对非正确的预测值,只关心正确预测值是否足够大。**因为正确值是只有一个元素为1的列向量。

常用的损失函数

L2 Loss:均方损失。

image-20231112101555142

L1 Loss:绝对值损失。

image-20231112101829868

L2 梯度是一条倾斜直线,对于梯度下降算法等更为合适;L1 是一个跳变,梯度要么 -1 要么 1. 如图是 L1 L2 的梯度。

image-20231112102551104

我们可以结合两者,得到一个新的损失函数(鲁棒损失 Huber Robust):

KaTeX parse error: {equation} can be used only in display mode.

image-20231112102721527

图像分类数据集

MINIST 是一个常用图像分类数据集,但是过于简单。后来的 upgrade 版叫 Fashion-MINIST(服装分类).

首先,我们研究研究怎么加载训练数据集,以便后面测试算法用。

# 导包
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()d2l.use_svg_display()# 下载数据集并读取到内存
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)		# 训练数据集
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)	# 测试数据集用于评估性能# 定义函数用于返回对应索引的标签
def get_fashion_mnist_labels(labels):  #@save"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# 图像可视化,让结果看着更直观,比如下面那个绿色图的样子
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save"""绘制图像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes# 我们先读一点数据集看看啥样的
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

1699980345931

# 通过内置数据加载器读取一批量数据,自动随机打乱读取,不需要我们自己定义
batch_size = 256def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

测量以上用时基本2-3s。

总结整合以上数据读取过程,代码如下:

def load_data_fashion_mnist(batch_size, resize=None):  #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

加载图像还可以调整其大小。

相关文章:

深度学习(五)softmax 回归之:分类算法介绍,如何加载 Fashion-MINIST 数据集

Softmax 回归 基本原理 回归和分类,是两种深度学习常用方法。回归是对连续的预测(比如我预测根据过去开奖列表下次双色球号),分类是预测离散的类别(手写语音识别,图片识别)。 现在我们已经对回…...

单稳态中间继电器\UEG/A-2H/220V 8A导轨安装 JOSEF约瑟

UEG系列中间继电器 UEG/A-2H2D中间继电器UEG/A-4H4D中间继电器UEG/A-2D中间继电器 UEG/A-2H中间继电器UEG/A-4H中间继电器UEG/A-4D中间继电器 UEG/A-6H中间继电器UEG/A-6D中间继电器UEG/A-8H中间继电器 UEG/A-10D中间继电器UEG/A-10H中间继电器UEG/A-2DPDT中间继电器 UEG/A-4DP…...

2311rust到20版本更新

Rust1.9 Rust1.9中最大的变化是稳定了包括停止恐慌启动的展开过程方法的std::panic模块: use std::panic; let result panic::catch_unwind(|| {println!("hello!"); }); assert!(result.is_ok()); let result panic::catch_unwind(|| {panic!("oh no!"…...

基于Spring、SpringMVC、MyBatis的漫画网站

文章目录 项目介绍主要功能截图:部分代码展示设计总结项目获取方式🍅 作者主页:超级无敌暴龙战士塔塔开 🍅 简介:Java领域优质创作者🏆、 简历模板、学习资料、面试题库【关注我,都给你】 🍅文末获取源码联系🍅 项目介绍 基于Spring、SpringMVC、MyBatis的漫画网…...

MySQL数据库八股文

MySQL数据库八股文 第一章 数据库基础 1. 数据库概念 数据库是存储数据的仓库,数据库管理系统是操纵和管理数据库的大型软件(如MySQL,InnoDB是其默认的存储引擎),SQL是操作关系型数据库的编程语言。 2. SQL语法与分…...

利用WebSocket +MQ发送紧急订单消息,并在客户端收到消息的用户的页面自动刷新列表

背景:在原有通知公告的基础上,把通知公共的推送服务修改为其他业务收到紧急订单发送公告到消息队列MQ,然后在js中创建一个socket去监听公告,收到公告后刷新所有在订单页面的用户的页面列表(重点就是用户在收到紧急订单…...

R语言——taxize(第一部分)

ropensci 系列之 taxize (中译手册) taxize 包1. taxize支持的网络数据源简介目前支持的API:针对Catalogue of Life(COL) 2. 浅尝 taxize 的一些使用例子2.1. **从NCBI上获取唯一的分类标识符**2.2. **获取分类信息**2…...

【Spring Cloud】黑马头条 用户服务创建、登录功能实现

点击去看上一篇 一、创建用户 model 1.创建用户数据库库 leadnews_user 核心表 ap_user 建库建表语句 这里一定要使用 navicat,执行SQL 文件,以防止 cmd 中的编码问题 先将 SQL 语句,保存在电脑中,再使用 navicat 打开 CREATE…...

聚观早报 |英伟达发布H200;夸克发布自研大模型

【聚观365】11月15日消息 英伟达发布H200 夸克发布自研大模型 iQOO 12系列开启销售 红魔9 Pro配置细节 禾赛科技第三季度营收4.5亿元 英伟达发布H200 全球市值最高的芯片制造商英伟达公司,正在升级其H100人工智能处理器,为这款产品增加更多功能&am…...

15项基本SCADA技术技能

1. 人机界面 人机界面是将操作员连接到设备、系统或机器的仪表板或用户界面。 以下是 hmi 在 scada 技术人员简历中的使用方式: 完成了查尔斯湖废水处理厂和提升站的完整 HMI 图形界面。对加油系统、加油车、PLC、HMI、触摸屏进行故障排除和维修。对 Horner HMI …...

Golang 发送邮件

Go 有内置好的本地库可以发送邮件,在 GitHub 上也有别人写好的第三方包可以发送邮件。 本文将分别介绍一下这两种发送邮件的方式。 1、内置的net/smtp 为了更好的模拟发送邮件,推荐一个邮件测试工具:MailHog,MailHog 是面向开发…...

【ARM Trace32(劳特巴赫) 使用介绍 5-- Trace32 通过 JTAG 命令获取数据寄存器 IDCODE的值】

请阅读【ARM Coresight SoC-400/SoC-600 专栏导读】 文章目录 Trace JTAG Command LineTrace32 JTAG 数据发送命令Trace32 JTAG 数据接收命令Trace32 数据访问修饰符Trace32 IDCODE 脚本实例Trace32 APITrace JTAG Command Line Trace32 JTAG 数据发送命令 JTAG.SHIFTTMS <…...

Python之while/for,continue/break

定义一个随机数&#xff1a; import random numrandom.randint(1,10) while循环&#xff1a; while 条件(): 条件满足时&#xff0c;做的事情1 条件满足时&#xff0c;做的事情2 ...... for循环&#xff1a; for 变量 in range(10): 循环需要执行的代码 else: 循环结束时&…...

卷积神经网络(CNN)衣服图像分类的实现

文章目录 前期工作1. 设置GPU&#xff08;如果使用的是CPU可以忽略这步&#xff09;我的环境&#xff1a; 2. 导入数据3.归一化4.调整图片格式5. 可视化 二、构建CNN网络模型三、编译模型四、训练模型五、预测六、模型评估 前期工作 1. 设置GPU&#xff08;如果使用的是CPU可以…...

odoo16前端框架源码阅读——env.js

env.js&#xff08;env的初始化以及服务的加载&#xff09; 路径&#xff1a;addons\web\static\src\env.js 这个文件的作用就是初始化env&#xff0c;主要是加载所有的服务。如orm, title, dialog等。 1、env.js 的加载时机 前文我们讲过前端的启动函数&#xff0c;start.…...

浙大恩特客户资源管理系统 SQL注入漏洞复现

0x01 产品简介 浙大恩特客户资源管理系统是一款针对企业客户资源管理的软件产品。该系统旨在帮助企业高效地管理和利用客户资源&#xff0c;提升销售和市场营销的效果。 0x02 漏洞概述 浙大恩特客户资源管理系统中T0140_editAction.entweb接口处存在SQL注入漏洞&#xff0c;未…...

ESP32网络开发实例-BME280传感器数据保存到InfluxDB时序数据库

BME280传感器数据保存到InfluxDB时序数据库 文章目录 BME280传感器数据保存到InfluxDB时序数据库1、BM280和InfluxDB介绍2、软件准备3、硬件准备4、代码实现在本文中,将详细介绍如何将BME280传感器数据上传到InfluxDB中,方便后期数据处理。 1、BM280和InfluxDB介绍 InfluxDB…...

C++中sort()函数的greater<int>()参数

目录 1 基础知识2 模板3 工程化 1 基础知识 sort()函数中的greater<int>()参数表示将容器内的元素降序排列。不填此参数&#xff0c;默认表示升序排列。 vector<int> a {1,2,3}; sort(a.begin(), a.end(), greater<int>()); //将a降序排列 sort(a.begin()…...

2024有哪些免费的mac苹果电脑内存清理工具?

在我们日常使用苹果电脑的过程中&#xff0c;随着时间的推移&#xff0c;可能会发现设备的速度变慢了&#xff0c;甚至出现卡顿的现象。其中一个常见的原因就是程序占用内存过多&#xff0c;导致系统无法高效地运行。那么&#xff0c;苹果电脑内存怎么清理呢&#xff1f;本文将…...

线性表的概念

目录 1.什么叫线性表2.区分线性表的题 1.什么叫线性表 线性表&#xff08;linear list&#xff09;是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结构&#xff0c;常见的线性表&#xff1a;顺序表、链表、栈、队列、字符串… 线性表在逻辑上是…...

SkyWalking 10.2.0 SWCK 配置过程

SkyWalking 10.2.0 & SWCK 配置过程 skywalking oap-server & ui 使用Docker安装在K8S集群以外&#xff0c;K8S集群中的微服务使用initContainer按命名空间将skywalking-java-agent注入到业务容器中。 SWCK有整套的解决方案&#xff0c;全安装在K8S群集中。 具体可参…...

Cinnamon修改面板小工具图标

Cinnamon开始菜单-CSDN博客 设置模块都是做好的&#xff0c;比GNOME简单得多&#xff01; 在 applet.js 里增加 const Settings imports.ui.settings;this.settings new Settings.AppletSettings(this, HTYMenusonichy, instance_id); this.settings.bind(menu-icon, menu…...

数据链路层的主要功能是什么

数据链路层&#xff08;OSI模型第2层&#xff09;的核心功能是在相邻网络节点&#xff08;如交换机、主机&#xff09;间提供可靠的数据帧传输服务&#xff0c;主要职责包括&#xff1a; &#x1f511; 核心功能详解&#xff1a; 帧封装与解封装 封装&#xff1a; 将网络层下发…...

Unit 1 深度强化学习简介

Deep RL Course ——Unit 1 Introduction 从理论和实践层面深入学习深度强化学习。学会使用知名的深度强化学习库&#xff0c;例如 Stable Baselines3、RL Baselines3 Zoo、Sample Factory 和 CleanRL。在独特的环境中训练智能体&#xff0c;比如 SnowballFight、Huggy the Do…...

Android15默认授权浮窗权限

我们经常有那种需求&#xff0c;客户需要定制的apk集成在ROM中&#xff0c;并且默认授予其【显示在其他应用的上层】权限&#xff0c;也就是我们常说的浮窗权限&#xff0c;那么我们就可以通过以下方法在wms、ams等系统服务的systemReady()方法中调用即可实现预置应用默认授权浮…...

在 Spring Boot 中使用 JSP

jsp&#xff1f; 好多年没用了。重新整一下 还费了点时间&#xff0c;记录一下。 项目结构&#xff1a; pom: <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://ww…...

QT开发技术【ffmpeg + QAudioOutput】音乐播放器

一、 介绍 使用ffmpeg 4.2.2 在数字化浪潮席卷全球的当下&#xff0c;音视频内容犹如璀璨繁星&#xff0c;点亮了人们的生活与工作。从短视频平台上令人捧腹的搞笑视频&#xff0c;到在线课堂中知识渊博的专家授课&#xff0c;再到影视平台上扣人心弦的高清大片&#xff0c;音…...

【java】【服务器】线程上下文丢失 是指什么

目录 ■前言 ■正文开始 线程上下文的核心组成部分 为什么会出现上下文丢失&#xff1f; 直观示例说明 为什么上下文如此重要&#xff1f; 解决上下文丢失的关键 总结 ■如果我想在servlet中使用线程&#xff0c;代码应该如何实现 推荐方案&#xff1a;使用 ManagedE…...

精益数据分析(98/126):电商转化率优化与网站性能的底层逻辑

精益数据分析&#xff08;98/126&#xff09;&#xff1a;电商转化率优化与网站性能的底层逻辑 在电子商务领域&#xff0c;转化率与网站性能是决定商业成败的核心指标。今天&#xff0c;我们将深入解析不同类型电商平台的转化率基准&#xff0c;探讨页面加载速度对用户行为的…...

6.9本日总结

一、英语 复习默写list11list18&#xff0c;订正07年第3篇阅读 二、数学 学习线代第一讲&#xff0c;写15讲课后题 三、408 学习计组第二章&#xff0c;写计组习题 四、总结 明天结束线代第一章和计组第二章 五、明日计划 英语&#xff1a;复习l默写sit12list17&#…...