【动手学深度学习Pytorch】2. Softmax回归代码
零实现
导入所需要的包:
import torch
from IPython import display
from d2l import torch as d2l
定义数据集参数、模型参数:
batch_size = 256 # 每次随机读取256张图片
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
# 将展平每个图片将其视为长度为784的向量,数据集存在10个类别
num_inputs = 784
num_outputs = 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)
实现Softmax操作:
# 实现Softmax
def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True) #列数为特征数,行数为样本数return X_exp / partition #广播机制# 尝试进行Softmax操作
X = torch.normal(0, 1, (2,5))
X_prob = softmax(X)
X_prob, X_prob.sum(1)# 实现Softmax回归模型
def net(X):return softmax(torch.matmul(X.reshape(-1,W.shape[0]),W)+b)
定义交叉熵函数:
# 创建一个数据y_hat,其中包含2个样本在3个类别的预测概率,使用y作为y_hat中概率的索引
y = torch.tensor([0,2])
y_hat = torch.tensor([[0.1, 0.3, 0.6],[0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
# 交叉熵函数
def cross_entropy(y_hat, y):return -torch.log(y_hat[range(len(y_hat)),y])
cross_entropy(y_hat, y)
将预测类别于真实元素进行比较:
torch.argmax(input, dim=None, keepdim=False):用于返回指定维度中最大值的索引。通常用于分类任务中从预测输出中找到概率最大的类别
.dtype:
.dtype
是张量的属性,用于返回该张量的 数据类型 (data type)。每个张量都有一个数据类型,用于定义其中存储元素的类型,例如浮点数、整数或布尔值。
tensor.type(dtype=None):不传入参数时,返回一个字符串,表示张量的类型;传入参数时,返回一个新的张量,该张量的类型与指定类型匹配。
x = torch.tensor([1.0, 2.0, 3.0]) # 默认 float32 类型 print(x.type()) # 输出: torch.FloatTensorx_int = x.type(torch.int64) print(x_int) # 输出: tensor([1, 2, 3]) print(x_int.type()) # 输出: torch.LongTensor (int64 的别名)
net.eval():设置为评估模式。
def accuracy(y_hat, y):#计算预测争取的数量# 判断 y_hat 是否为多维张量(例如二维)if len(y_hat.shape)>1 and y_hat.shape[1] > 1:# 如果是多类别分类(第二维大于 1),通过argmax获取每行中概率或分数最大的类别索引y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype)==y # 比较预测结果和真实标签是否相等return float(cmp.type(y.dtype).sum()) # 返回预测正确的总数量accuracy(y_hat, y) / len(y)def evaluate_accuracy(net, data_iter):#计算在指定数据集上的模型精度# 如果是 PyTorch 模型,设置为评估模式if isinstance(net, torch.nn.Module):net.eval() metric = Accumulator(2) # 初始化累加器,存储 [正确预测数, 总样本数]for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel()) # 累加每批数据的预测结果return metric[0] / metric[1] # 返回精度:正确预测数 / 总样本数
Accumulator实例:
class Accumulator: #在n个变量上累加def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]evaluate_accuracy(net, test_iter)
定义训练过程:
net.train():设置为训练模式。
torch.optim.Optimizer.step():用于执行模型参数更新。基于之前计算好的梯度(通过反向传播获得),按照优化算法的规则调整模型参数的值,以最小化损失函数。
def train_epoch_ch3(net, train_iter, loss, updater):if isinstance(net, torch.nn.Module):net.train()metric = Accumulator(3)for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y) #计算损失if isinstance(updater, torch.optim.Optimizer):updater.zero_grad() # 清除梯度l.backward() # 反向传播计算梯度updater.step() # 根据梯度更新模型参数metric.add(float(l) * len(y), # 累加当前批次的损失accuracy(y_hat, y), # 累加当前批次的正确预测数y.size().numel()) # 累加当前批次的样本数else: # 如果是自定义优化器l.sum().backward()updater(X.shape[0]) # 自定义的更新函数,可能需要批次大小作为参数metric.add(float(l.sum()), accuracy(y_hat),y.numel())return metric[0] / metric[2], metric[1] / metric[2]
定义一个在动画中绘制数据的实用程序类:
class Animator: #实时观看在训练过程中的变化# 初始化绘图环境,包括图表的设置、标签、坐标轴范围、曲线样式等。def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-','m--','g-','r:'),nrows=1,ncols=1,figsize=(3.5, 2,5)):if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols ==1:self.axes = [self.axes,]self.config_axes = lambda:d2l.set_axes(self.axes[0],xlabel, ylabel,xlim, ylim,xscale, yscale,legend)self.X, self.Y, self.fmt = None, None, fmtsdef add(self, x, y):if not hasattr(y, "__len__"):y = [y]n = len(y)
训练函数:
# 训练函数
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):# 进行可视化animator = Aminator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3,],legend=['train loss','train acc','test acc'])for epoch in range(num_epochs):train_metrics = train_epoch_ch2(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)animator.add(epoch+1, train_metrics+(test_acc,))train_loss, train_acc = train_metrics# 小批量随机梯度下降来优化训练算法
lr = 0.1
def updater(batch_size):return d2l.sgd([W,b],lr,batch_size)num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater(10))
简洁实现
导入所需要的包:
import torch
from IPython import display
from d2l import torch as d2l
初始化数据集、模型参数、损失函数以及训练优化算法:网络加入高斯噪声,增强泛化性。
torch.nn.init.normal_(tensor, mean=0.0, std=1.0):正态分布(高斯分布)随机初始化张量的值
nn.Sequential(*modules):用于将多个模块(如线性层、激活函数等)按顺序组合成一个模型。适合简单的前向计算场景。
nn.Flatten(start_dim=1, end_dim=-1):将输入张量展平成二维张量,适用于线性层输入。
nn.Linear(in_features, out_features, bias=True):实现一个线性层(全连接层)
nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean'):计算分类任务中的交叉熵损失(适用于多分类问题)。
torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False):实现随机梯度下降(SGD)优化算法,用于更新模型参数。net.parameters():返回模型的可训练参数的迭代器。
batch_size = 256 # 每次随机读取256张图片
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)net = nn.Sequential(nn.Flatten(),nn.Linear(784, 100))
def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights);loss = nn.CrossEntropyLoss()trainer = torch.optim.SGD(net.parameters(),lr=0.1)
用之前定义的训练函数训练模型:
num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater(10))
相关文章:

【动手学深度学习Pytorch】2. Softmax回归代码
零实现 导入所需要的包: import torch from IPython import display from d2l import torch as d2l定义数据集参数、模型参数: batch_size 256 # 每次随机读取256张图片 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size) # 将展平每个…...
技术周总结 11.11~11.17 周日(Js JVM XML)
文章目录 一、11.11 周一1.1)问题01:js中的prompt弹窗区分出来用户点击的是 确认还是取消进一步示例 1.2)问题02:在 prompt弹窗弹出时默认给弹窗中写入一些内容 二、11.12 周二2.1) 问题02: 详解JVM中的本地方法栈本地方法栈的主要…...

MATLAB 使用教程 —— 矩阵和数组
矩阵和数组MATLAB 中矩阵和数组长什么样?MATLAB 怎么用矩阵计算?创建和操作矩阵矩阵运算示例串联 访问矩阵的元素 矩阵和数组 MATLAB 是“matrix laboratory”的缩写形式。MATLAB 主要用于处理 整个的矩阵和数组,而其他编程语言大多逐个处理…...

React教程第二节之虚拟DOM与Diffing算法理解
1、什么是虚拟DOM 虚拟DOM 是javascript的一个对象,是内存中的一种数据结构,以树的形式存储UI的状态,树中的每个节点都代表着真实的DOM,用来描述我们希望在页面看到的 HTML结构; 现在的MVVM 框架,大多使用…...

C++——类和对象(part2)
前言 本篇博客继续为大家介绍类与对象的知识,承接part1的内容,本篇内容是类与对象的核心内容,稍微有些复杂,如果你对其感兴趣,请继续阅读,下面进入正文部分。 1. 类的默认成员函数 默认成员函数就是用户…...
【FFmpeg系列】:音频处理
前言 在多媒体处理领域,FFmpeg无疑是一个不可或缺的利器。它功能强大且高度灵活,能够轻松应对各种音频和视频处理任务,无论是简单的格式转换,还是复杂的音频编辑,都不在话下。然而,要想真正发挥FFmpeg的潜…...

Python绘制雪花
文章目录 系列目录写在前面技术需求完整代码代码分析1. 代码初始化部分分析2. 雪花绘制核心逻辑分析3. 窗口保持部分分析4. 美学与几何特点总结 写在后面 系列目录 序号直达链接爱心系列1Python制作一个无法拒绝的表白界面2Python满屏飘字表白代码3Python无限弹窗满屏表白代码4…...

vue3 如何调用第三方npm包内部的 pinia 状态管理库方法
抛砖引玉: 如果在开发vue3项目是, 引用了npm第三方包 ,而且这个包内使用了Pinia 状态管理库,那我们如何去调用 npm内部的 Pinia 状态管理库呢? 实际遇到的问题: 今天在制作npm包时遇到的问题,之前Vue2版本的时候状态管理库用的Vuex ,当时调用npm包内的状态管理库很简单,直接引…...
uni-app快速入门(七)--组件路由跳转和API路由跳转及参数传递
uni-app有两种页面路由跳转模式,即使用navigator组件跳转和调用API跳转,API调转不要理解为调用后台接口的API,而是指脚本函数中使用跳转函数。 一、组件路由跳转 1.1 打开新页面 打开新页面使用组件的open-type"navigate",见下面…...
Flink升级程序和版本
Flink DataStream程序通常设计为长时间运行,如几周、几个月甚至几年。与所有长时间运行的服务一样,Flink streaming应用程序也需要维护,包括修复错误、实现改进或将应用程序迁移到更高版本的Flink集群。 这里就来描述下如何更新Flink streaming应用程序,以及如何将正在运行…...
从0安装mysql server
安装 MySQL Server 首先,你需要在 Ubuntu 上安装 MySQL 服务器。运行以下命令来安装:sudo apt update sudo apt install mysql-server安装完成后,MySQL 服务会自动启动。你可以通过以下命令检查 MySQL 服务是否正在运行: sudo systemctl status mysql如果 MySQL 正在运行,…...

web安全测试渗透案例知识点总结(上)——小白入狱
目录 一、Web安全渗透测试概念详解1. Web安全与渗透测试2. Web安全的主要攻击面与漏洞类型3. 渗透测试的基本流程 二、知识点详细总结1. 常见Web漏洞分析2. 渗透测试常用工具及其功能 三、具体案例教程案例1:SQL注入漏洞利用教程案例2:跨站脚本ÿ…...
PHP访问NetSuite REST Web Services
“同等看待欢乐和痛苦、得到和失去、胜利和失败、投入战斗。以此方式履行职责,你就不会招致任何罪恶。” -Bhagavad Gita 为了帮助PHP开发者快速起步,以REST Web Services方式打通与NetSuite的接口,我们答应给一个样例。但是我是不懂PHP的&a…...

【编译】多图解释 什么是短语、直接短语、句柄、素短语、可归约串
一、什么是短语二、什么是“直接”短语?三、什么是句柄?四、什么是素短语?五、什么是最左素短语可归约串就是“最左素短语” 首先,这些概念 都是相对于【句型】的,都是相对于【句型】的,都是相对于【句型】…...
React中事件绑定和Vue有什么区别?
1. 绑定方式 React:使用jsx语法,通过属性绑定事件。Vue:使用指令(如v-on)在模板中直接绑定事件。 2. 事件处理 React:通过合成事件系统封装原生事件,提供统一的API。Vue:直接使用…...
【DBA攻坚指南:左右Oracle,右手MySQL-学习总结】
处理log file sync等待事件 首先明确什么是log file sync等待事件 从用户提交会话开始,LGWR进程将redo缓存中的信息写入redo日志文件后,LGWR进程通知用户写操作完成,到用户会话接受到LGWR进程通知为止,这整个过程就是可能出现lo…...
C++中的内联函数
在C中,内联函数是一种特殊的函数。 定义 内联函数是在函数定义前加上关键字“inline”的函数。编译器在处理对内联函数的调用时,会尝试将函数体的代码直接插入到函数调用处,而不是像普通函数调用那样,进行跳转指令执行函数体代码…...
ssh.service could not be found“
如果你收到 “ssh.service could not be found” 错误,说明目标主机上没有安装 SSH 服务,或者安装的 SSH 服务的名称不为 ssh。这里有一些解决步骤: 1. 检查 SSH 服务是否已安装 在目标主机上执行以下命令来检查是否安装了 SSH 服务&#x…...
tensorflow有哪些具体影响,和chatgpt有什么关系
### TensorFlow的影响 **1. 深度学习框架的领军者** - **广泛使用**: TensorFlow是由Google开发的开源深度学习框架,广泛应用于各种机器学习任务,包括图像识别、自然语言处理、语音识别等。它是深度学习领域中最受欢迎的框架之一。 - **大规模生产环境*…...

Android OpenGL ES详解——几何着色器
目录 一、概念 1、图元 2、几何着色器 1、输入类型 2、输出类型 3、输出顶点数量最大值限制 二、使用几何着色器 三、应用举例——造几个房子 四、应用举例——爆破物体 1、获取法向量 2、显示法线 五、应用举例——细分三角形 六、应用举例——广告牌技术 一、概…...

利用最小二乘法找圆心和半径
#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

MPNet:旋转机械轻量化故障诊断模型详解python代码复现
目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...
设计模式和设计原则回顾
设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...

《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》
引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...
【Java学习笔记】Arrays类
Arrays 类 1. 导入包:import java.util.Arrays 2. 常用方法一览表 方法描述Arrays.toString()返回数组的字符串形式Arrays.sort()排序(自然排序和定制排序)Arrays.binarySearch()通过二分搜索法进行查找(前提:数组是…...
深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法
深入浅出:JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中,随机数的生成看似简单,却隐藏着许多玄机。无论是生成密码、加密密钥,还是创建安全令牌,随机数的质量直接关系到系统的安全性。Jav…...

【JVM】- 内存结构
引言 JVM:Java Virtual Machine 定义:Java虚拟机,Java二进制字节码的运行环境好处: 一次编写,到处运行自动内存管理,垃圾回收的功能数组下标越界检查(会抛异常,不会覆盖到其他代码…...
C# SqlSugar:依赖注入与仓储模式实践
C# SqlSugar:依赖注入与仓储模式实践 在 C# 的应用开发中,数据库操作是必不可少的环节。为了让数据访问层更加简洁、高效且易于维护,许多开发者会选择成熟的 ORM(对象关系映射)框架,SqlSugar 就是其中备受…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...

【数据分析】R版IntelliGenes用于生物标志物发现的可解释机器学习
禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍流程步骤1. 输入数据2. 特征选择3. 模型训练4. I-Genes 评分计算5. 输出结果 IntelliGenesR 安装包1. 特征选择2. 模型训练和评估3. I-Genes 评分计…...