当前位置: 首页 > news >正文

线性神经网络——softmax 回归随笔【深度学习】【PyTorch】【d2l】

文章目录

    • 3.2、softmax 回归
      • 3.2.1、softmax运算
      • 3.2.2、交叉熵损失函数
      • 3.2.3、PyTorch 从零实现 softmax 回归
      • 3.2.4、简单实现 softmax 回归

在这里插入图片描述

3.2、softmax 回归

3.2.1、softmax运算

在这里插入图片描述

softmax 函数是一种常用的激活函数,用于将实数向量转换为概率分布向量。它在多类别分类问题中起到重要的作用,并与交叉熵损失函数结合使用。

y ^ = s o f t m a x ( o ) 其中     y ^ i = e x p ( o j ) ∑ k e x p ( o k ) \hat{y} = softmax(o) \ \ \ \ \ 其中\ \ \ \ \hat{y}_i = \frac{exp(o_j)}{\sum_{k}exp(o_k)} y^=softmax(o)     其中    y^i=kexp(ok)exp(oj)

其中,O为小批量的未规范化的预测, Y ^ \hat{Y} Y^w为输出概率,是一个正确的概率分布【 ∑ y i = 1 \sum{y_i} =1 yi=1

3.2.2、交叉熵损失函数

通过测量给定模型编码的比特位,来衡量两概率分布之间的差异,是分类问题中常用的 loss 函数。

H ( P , Q ) = − Σ P ( x ) ∗ l o g ( Q ( x ) ) H(P, Q) = -Σ P(x) * log(Q(x)) H(P,Q)=ΣP(x)log(Q(x))

真实概率分布是从哪里得知的?

真实标签的概率分布是由数据集中的标签信息提供的,通常使用单热编码表示。

softmax() 如何与交叉熵函数搭配的?

softmax 函数与交叉熵损失函数常用于多分类任务中。softmax 函数用于将模型输出转化为概率分布形式,交叉熵损失函数用于衡量模型输出概率分布与真实标签的差异,并通过优化算法来最小化损失函数,从而训练出更准确的分类模型。

3.2.3、PyTorch 从零实现 softmax 回归

(非完整代码)

#在 Notebook 中内嵌绘图
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l#,将图形显示格式设置为 SVG 格式,以在 Notebook 中以矢量图形的形式显示图像。这有助于提高图像的清晰度和可缩放性。
d2l .use_svg_display()

在线下载数据集 Fashion-MNIST

#将图像数据转换为张量形式
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data",train=False,transform =trans,download=True)len(mnist_train),len(mnist_test)

绘图(略)

读取小批量数据集

batch_size = 256def get_dataloader_workers():"""使用4进程读取"""return 4train_iter = data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers())
timer = d2l.Timer()
for X,y in train_iter:continue
print(f'{timer.stop():.2f}sec')

定义softmax操作

def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp / partition  # 这里应用了广播机制

定义损失函数

def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])cross_entropy(y_hat, y)

分类精度

def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())

评估

def evaluate_accuracy(net, data_iter):  #@save"""计算在指定数据集上模型的精度"""if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]
class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]

3.2.4、简单实现 softmax 回归

导入前面已下载数据集 Fashion-MNIST

import torch 
from torch import nn
from d2l import torch as d2lbatch_size =256
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)

初始化模型

#nn.Flatten() 层的作用是将输入数据展平,将二维输入(如图像)转换为一维向量。因为线性层(nn.Linear)通常期望接收一维输入。
#nn.Linear(784,10) 将输入特征从 784 维降低到 10 维,用于图像分类问题中的 10 个类别的预测   784维向量->10维向量
net = nn.Sequential(nn.Flatten(),nn.Linear(784,10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight,std=0.01)net.apply(init_weights);
#计算交叉熵损失函数,用于衡量模型预测与真实标签之间的差异。参数 reduction 控制了损失的计算方式。
#reduction='none' 表示不进行损失的降维或聚合操作,即返回每个样本的独立损失值。
loss = nn.CrossEntropyLoss(reduction='none')

优化算法

trainer = torch.optim.SGD(net.parameters(),lr=0.1)

训练

num_epochs = 10
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)

相关文章:

线性神经网络——softmax 回归随笔【深度学习】【PyTorch】【d2l】

文章目录 3.2、softmax 回归3.2.1、softmax运算3.2.2、交叉熵损失函数3.2.3、PyTorch 从零实现 softmax 回归3.2.4、简单实现 softmax 回归 3.2、softmax 回归 3.2.1、softmax运算 softmax 函数是一种常用的激活函数,用于将实数向量转换为概率分布向量。它在多类别…...

【Nodejs】Node.js开发环境安装

1.版本介绍 在命令窗口中输入 node -v 可以查看版本 0.x 完全不技术 ES64.x 部分支持 ES6 特性5.x 部分支持ES6特性(比4.x多些),属于过渡产品,现在来说应该没有什么理由去用这个了6.x 支持98%的 ES6 特性8.x 支持 ES6 特性 2.No…...

梅尔频谱(Mel spectrum)简介及Python实现

梅尔频谱(Mel spectrum)简介及Python实现 1. 梅尔频谱(Mel spectrum)简介2. Python可视化测试3.频谱可视化3.1 Mel 频谱可视化3.2 STFT spectrum参考文献资料1. 梅尔频谱(Mel spectrum)简介 在信号处理上,声信号(噪声信号)是一种重要的传感监测手段。对于语音分类任务…...

【数据结构】实验六:队列

实验六 队列 一、实验目的与要求 1)熟悉C/C语言(或其他编程语言)的集成开发环境; 2)通过本实验加深对队列的理解,熟悉基本操作; 3) 结合具体的问题分析算法时间复杂度。 二、…...

【Linux线程】第一章||理解线程概念+创建一个线程(附代码加讲解)

线程概念 🌵什么是线程🌲线程和进程的关系🎄线程有以下特点:🌳 线程的优点🌴 线程的缺点🌱线程异常🌿线程用途 ☘️手动创建一个进程🍀运行 🌵什么是线程 在L…...

Android进阶之微信扫码登录

遇到新需求要搭建微信扫码登录功能,这篇文章是随着我的编码过程一并写的,希望能够帮助有需求的人和以后再次用到此功能的自己。 首先想到的就是百度各种文章,当然去开发者平台申请AppID和密钥是必不可少的,等注册好发现需要创建应用以及审核(要官网,流程图及其他信息),想着先写…...

macOS Monterey 12.6.8 (21G725) Boot ISO 原版可引导镜像

macOS Monterey 12.6.8 (21G725) Boot ISO 原版可引导镜像 本站下载的 macOS 软件包,既可以拖拽到 Applications(应用程序)下直接安装,也可以制作启动 U 盘安装,或者在虚拟机中启动安装。另外也支持在 Windows 和 Lin…...

Unity自定义后处理——用偏导数求图片颜色边缘

大家好,我是阿赵。   继续介绍屏幕后处理效果的做法。这次介绍一下用偏导数求图形边缘的技术。 一、原理介绍 先来看例子吧。   这个例子看起来好像是要给模型描边。之前其实也介绍过很多描边的方法,比如沿着法线方向放大模型,或者用Ndo…...

本地Git仓库和GitHub仓库SSH传输

SSH创建命令解释 ssh-keygen 用于创建密钥的程序 -m PEM 将密钥的格式设为 PEM -t rsa 要创建的密钥类型,本例中为 RSA 格式 -b 4096 密钥的位数,本例中为 4096 -C “azureusermyserver” 追加到公钥文件末尾以便于识别的注释。 通常以电子邮件地址…...

【C++11】——右值引用、移动语义

目录 1. 基本概念 1.1 左值与左值引用 1.2 右值和右值引用 1.3 左值引用与右值引用 2. 右值引用实用场景和意义 2.1 左值引用的使用场景 2.2 左值引用的短板 2.3 右值引用和移动语义 2.3.1 移动构造 2.3.2 移动赋值 2.3.3 编译器做的优化 2.3.4 总结 2.4 右值引用…...

消息服务概述

消息服务的作用: 在多数应用尤其是分布式系统中,消息服务是不可或缺的重要部分,它使用起来比较简单,同时解决了不少难题,例如异步处理、应用解耦、流量削锋、分布式事务管理等,使用消息服务可以实现一个高…...

【Spring Boot】Web开发 — 数据验证

Web开发 — 数据验证 对于应用系统而言,任何客户端传入的数据都不是绝对安全有效的,这就要求我们在服务端接收到数据时也对数据的有效性进行验证,以确保传入的数据安全正确。接下来介绍Spring Boot是如何实现数据验证的。 1.Hibernate Vali…...

技术分享 | App常见bug解析

功能Bug 内容显示错误 前端页面展示的内容有误。 这种错误的产生有两种可能 1、前端代码写的文案错误 2、接口返回值错误 功能错误 功能错误是在测试过程中最常见的类型之一,也就是产品的功能没有实现。比如图中的公众号登录不成功的问题。 界面展示错乱 产…...

树莓派Pico|RP2040|使用SWD进行调试|构建 “Hello World“ debug版本

文章目录 使用SWD进行调试构建 "Hello World" debug版本安装 GDB使用 GDB 和 OpenOCD 来 debug Hello World TIP重要提示 使用SWD进行调试 基于rp2040的板上的SWD端口重置,加载和运行代码,如树莓派Pico可用于交互式调试已加载的程序。这包括:…...

Ubuntu18.04 下配置Clion

配置Clion 安装gcc、g、make Ubuntu中用到的编译工具是gcc©,g(C),make(连接)。因此只需安装对应的工具包即可。Ubuntu下使用命令安装这些包: (1)安装gcc sudo apt install gcc&am…...

数据库管理-第九十四期 19c OCM之路-第四堂(02)(20230725)

第九十四期 19c OCM之路-第四堂(02)(20230725) 第四堂继续! 考点3:SQL statement tuning SQL语句调优 收集Schema统计信息 exec dbms_stats.gather_schems_stats(HR);开启制定表索引监控 create index…...

以智慧监测模式守护燃气安全 ,汉威科技“传感芯”凸显智慧力

城市燃气工程作为城市基建的重要组成部分,与城市居民生活、工业生产紧密相关。提升城市燃气服务质量和安全水平,也一直是政府和民众关注的大事。然而,近年来居民住宅、餐饮等工商业场所燃气事故频发,时刻敲响的警钟也折射出我国在…...

【阅读笔记】一种暗通道优先的快速自动白平衡算法

解决问题: 自动白平衡算法中存在白色区域检测错误导致白平衡失效的问题,作者提出了一种基于暗通道优先的白平衡算法。 算法思想: 图像中白色区域或者高饱和度区域的光线透射率较低,根据以上特性利用暗通道法计算图像中白色区域。 算法概述: 作者使用何凯明提出的基于暗…...

OpenStack之云主机管理

一&#xff09;必备知识 1.云主机与快照管理 a-云主机管理 云主机管理是OpenStack云计算平台的核心功能&#xff0c;通常&#xff0c;云主机的管理包括创建、删除、查询等。可使用以下命令对OpenStack的云主机进行管理&#xff1a; openstack server <操作><云主机…...

Linux系列---【Ubuntu 20.04安装KVM】

Ubuntu 20.04安装KVM 一、安装kvm 1.安装kvm sudo apt install qemu-kvm libvirt-daemon-system libvirt-clients bridge-utils 2. 将当前用户添加至libvirt 、 kvm组 sudo adduser $USER libvirt sudo adduser $USER kvm 3.验证安装 virsh list --all 4.启动libvert sudo syst…...

浅谈 React Hooks

React Hooks 是 React 16.8 引入的一组 API&#xff0c;用于在函数组件中使用 state 和其他 React 特性&#xff08;例如生命周期方法、context 等&#xff09;。Hooks 通过简洁的函数接口&#xff0c;解决了状态与 UI 的高度解耦&#xff0c;通过函数式编程范式实现更灵活 Rea…...

网络六边形受到攻击

大家读完觉得有帮助记得关注和点赞&#xff01;&#xff01;&#xff01; 抽象 现代智能交通系统 &#xff08;ITS&#xff09; 的一个关键要求是能够以安全、可靠和匿名的方式从互联车辆和移动设备收集地理参考数据。Nexagon 协议建立在 IETF 定位器/ID 分离协议 &#xff08;…...

MPNet:旋转机械轻量化故障诊断模型详解python代码复现

目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...

【Linux】shell脚本忽略错误继续执行

在 shell 脚本中&#xff0c;可以使用 set -e 命令来设置脚本在遇到错误时退出执行。如果你希望脚本忽略错误并继续执行&#xff0c;可以在脚本开头添加 set e 命令来取消该设置。 举例1 #!/bin/bash# 取消 set -e 的设置 set e# 执行命令&#xff0c;并忽略错误 rm somefile…...

.Net框架,除了EF还有很多很多......

文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

P3 QT项目----记事本(3.8)

3.8 记事本项目总结 项目源码 1.main.cpp #include "widget.h" #include <QApplication> int main(int argc, char *argv[]) {QApplication a(argc, argv);Widget w;w.show();return a.exec(); } 2.widget.cpp #include "widget.h" #include &q…...

C# 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...

IP如何挑?2025年海外专线IP如何购买?

你花了时间和预算买了IP&#xff0c;结果IP质量不佳&#xff0c;项目效率低下不说&#xff0c;还可能带来莫名的网络问题&#xff0c;是不是太闹心了&#xff1f;尤其是在面对海外专线IP时&#xff0c;到底怎么才能买到适合自己的呢&#xff1f;所以&#xff0c;挑IP绝对是个技…...

招商蛇口 | 执笔CID,启幕低密生活新境

作为中国城市生长的力量&#xff0c;招商蛇口以“美好生活承载者”为使命&#xff0c;深耕全球111座城市&#xff0c;以央企担当匠造时代理想人居。从深圳湾的开拓基因到西安高新CID的战略落子&#xff0c;招商蛇口始终与城市发展同频共振&#xff0c;以建筑诠释对土地与生活的…...

[大语言模型]在个人电脑上部署ollama 并进行管理,最后配置AI程序开发助手.

ollama官网: 下载 https://ollama.com/ 安装 查看可以使用的模型 https://ollama.com/search 例如 https://ollama.com/library/deepseek-r1/tags # deepseek-r1:7bollama pull deepseek-r1:7b改token数量为409622 16384 ollama命令说明 ollama serve #&#xff1a…...