pytorch学习(7)——神经网络优化器torch.optim
1 optim 优化器
PyTorch神经网络优化器(optimizer)通过调整神经网络的参数(weight和bias)来最小化损失函数(Loss)。
学习链接:
https://pytorch.org/docs/stable/optim.html

1.1 优化器基类
使用时必须构造一个优化器对象,它将保存当前状态,并将根据计算的梯度(grad)更新参数。
调用优化器的step方法。
CLASS torch.optim.Optimizer(params, defaults)
- Optimizer - 优化器的优化算法。
- params (iterable) – torch的迭代器。张量s或dict s,指定应该优化什么张量。
- defaults – (dict): 包含优化选项默认值的字典(在参数组没有指定优化选项时使用)。每个Optimizer算法都有其独特的设置字典。
| 算法(Optimizer) | 说明 |
|---|---|
| Adadelta | 采用Adadelta算法。 |
| Adagrad | 采用Adagrad算法。 |
| Adam | 采用Adam算法。 |
| AdamW | 采用AdamW算法。 |
| SparseAdam | 采用适合稀疏张量的Adam算法的惰性版本。 |
| Adamax | 采用Adamax算法(Adam基于无穷范数的变种)。 |
| ASGD | 采用平均随机梯度下降。 |
| LBFGS | 采用L-BFGS算法,深受minFunc的启发。 |
| NAdam | 采用NAdam 算法。 |
| RAdam | 采用RAdam 算法。 |
| RMSprop | 采用RMSprop 算法。 |
| Rprop | 采用有弹性的反向传播算法。 |
| SGD | 采用随机梯度下降算法。 |
1.1.1 SGD 随机梯度下降算法
CLASS torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False)
- params (iterable) – iterable参数优化或字典定义参数组。
- lr (float) – 学习率,需要用户输入。
- momentum (float, optional) – 动量系数(默认值为0)。
- weight_decay (float, optional) – 权重衰减(L2惩罚) (默认值为0)
。 - dampening (float, optional) – 动量阻尼(默认值为0)。
- nesterov (bool, optional) – 使能Nesterov动量(默认值为
False)。
【Nesterov动量(Nesterov Momentum)是一种基于动量法的优化算法,用于加速神经网络的训练过程。它在随机梯度下降(SGD)的基础上进行改进,通过考虑参数更新前的动量信息来调整参数更新的方向。】 - maximize (bool, optional) – 根据目标最大化参数,而不是最小化参数(默认值为
False)。 - foreach (bool, optional) – 是否使用foreach优化器的实现。如果用户未指定(foreach为
None),我们将尝试在CUDA上的for循环实现上使用foreach,因为CUDA通常性能更高(默认值:None)。 - differentiable (bool, optional) – 是否在训练中的优化器步骤中发生autograd。否则,step()函数在torch.no_grad()上下文中运行。设置为
True会影响性能,所以如果你不打算通过这个实例运行autograd,请保留False(默认值为False)。
学习速率(lr)的取值,如果太大,则模型很不稳定;如果太小,学习速度非常缓慢。因此一般先设置较大的学习速率,然后降低学习速率。
python代码如下:
import torchvision
import torch
from torch import nn, optim
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential,CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root="G:\\Anaconda\\pycharm_pytorch\\learning_project\\dataset_CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=False)dataloader = DataLoader(dataset, batch_size=1)class MYNN(nn.Module):def __init__(self):super(MYNN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 64, 5, padding=2, stride=1),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = CrossEntropyLoss()
mynn = MYNN()
opitm = optim.SGD(mynn.parameters(), lr=0.01) # 优化器for data in dataloader:imgs, targets = dataoutputs = mynn(imgs)result_loss = loss(outputs, targets)# print(outputs) # 神经网络输出# print(targets) # 目标# print(result_loss) # 损失函数-交叉熵计算结果opitm.zero_grad() # 梯度清零,设置断点result_loss.backward() # 反向传播,求出每个节点的梯度,设置断点opitm.step() # 对神经网络模型的参数进行调优,设置断点
设置断点,进入程序Debug:
(1)不断运行程序,能够观察到卷积层0的bias梯度变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> bias-> grad

(2)能观察到卷积层0的weight梯度变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> weight -> grad

(3)能观察到bias的变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> bias-> data

(4)能观察到weight的变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> weight-> data

(5)结论:运行opitm.zero_grad()后,清空weight和bias的梯度grad;运行result_loss.backward()后,计算得到新的weight和bias的梯度grad;运行opitm.step()后,调整weight和bias的值。
1.1.2 优化器多次循环
修改以上python代码,增加多次循环,观察总体损失值改变。
import torchvision
import torch
from torch import nn, optim
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential,CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root="G:\\Anaconda\\pycharm_pytorch\\learning_project\\dataset_CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=False)dataloader = DataLoader(dataset, batch_size=1)class MYNN(nn.Module):def __init__(self):super(MYNN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 64, 5, padding=2, stride=1),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = CrossEntropyLoss()
mynn = MYNN()
opitm = optim.SGD(mynn.parameters(), lr=0.01) # 优化器for epoch in range(20):running_loss = 0.0for data in dataloader:imgs, targets = dataoutputs = mynn(imgs)result_loss = loss(outputs, targets)# print(outputs) # 神经网络输出# print(targets) # 目标# print(result_loss) # 损失函数-交叉熵计算结果opitm.zero_grad() # 梯度清零result_loss.backward() # 反向传播,求出每个节点的梯度opitm.step() # 对神经网络模型的参数进行调优running_loss = running_loss + result_loss#.dataprint(running_loss)
运行结果:
tensor(18746.2012, grad_fn=<AddBackward0>)
tensor(16136.0107, grad_fn=<AddBackward0>)
tensor(15499.3203, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
可以发现running_loss在一开始不断降低,但是以下的nan暂时不知道是什么原因。
相关文章:
pytorch学习(7)——神经网络优化器torch.optim
1 optim 优化器 PyTorch神经网络优化器(optimizer)通过调整神经网络的参数(weight和bias)来最小化损失函数(Loss)。 学习链接: https://pytorch.org/docs/stable/optim.html 1.1 优化器基类 使…...
leetcode做题笔记101. 对称二叉树
给你一个二叉树的根节点 root , 检查它是否轴对称。 思路一:递归 bool isSymmetric(struct TreeNode* root){if (root NULL) return true;return fun(root->left, root->right); }int fun(struct TreeNode* l_root, struct TreeNode* r_root) {…...
边缘计算相关概念--学习笔记
一.边缘计算概念 边缘计算将数据的处理,应用程序的运行甚至一些功能服务的实现,由网络中心下放到网络边缘的节点上,在网络边缘侧的智能网关上就近采集并且处理数据,不需要将大量未处理的数据上传到远程的大数据平台。边缘计算理论…...
flutter windows编译错误 flutter_assemble.vcxproj
flutter 编译windows是出现错误。 [ 44 ms] d:\Program Files\Microsoft Visual Studio\2022\Community\MSBuild\Microsoft\VC\v170\Microsoft.CppCommon.targets(248,5): error MSB8066: ��E:\work\kkview_kuaichuan\kkview_kuaichuan\build\windows\C…...
通过运行中的容器生成 Docker Compose 配置文件
背景 笔者之前有一次不小心删除了原始的 docker-compose.yml 文件,不过正在运行的 Docker 容器还在,找了许久,发现一个方法可以从这些容器中生成一个等效的 Docker Compose 配置文件。本文将介绍使用 autocompose 工具从正在运行的容器中反向…...
rancher界面无法登陆问题解决,登录超时;
1.找到rancher主机,查看日志 docker ps | grep rancher # rancher 容器 名称 jolly_ptolemy docker logs -f jolly_ptolemy 日志提示, java.sql.SQLException: Got error 28 from storage engine,磁盘满了 2.磁盘管理 df -h #查看磁盘使…...
Django(6)-django项目自动化测试
Django 应用的测试应该写在应用的 tests.py 文件里。测试系统会自动的在所有以 tests 开头的文件里寻找并执行测试代码。 我们的 polls 应用现在有一个小 bug 需要被修复:我们的要求是如果 Question 是在一天之内发布的, Question.was_published_recentl…...
【AUTOSAR】【CAN通信】CanNm
目录 一、概述 二、说明 三、功能说明 3.1 协调算法 3.2 操作模式 3.2.1 网络模式...
拼多多淘宝大量缓存商品数据用什么格式提供比较好?
众所周知,淘宝拼多多是我国主流的电商平台,其上有大量的商品数据。很多商家会通过API来访问他们的商品数据,根据API的调用次数收费。第三方数据公司提供电商数据接口API,采集实时数据。但是,在他们的服务器上有大量的缓…...
【校招VIP】前端校招考点之页面转换算法
考点介绍: 在地址映射过程中,若在页面中发现所要访问的页面不在内存中,则产生缺页中断。当发生缺页中断时,如果操作系统内存中没有空闲页面,则操作系统必须在内存选择一个页面将其移出内存,以便为即将调入的…...
android 下载网络文件
工具类 import android.app.ProgressDialog; import android.content.Context; import android.os.AsyncTask; import android.os.Environment; import android.util.Log;import java.io.BufferedInputStream; import java.io.File; import java.io.FileOutputStream; import …...
springboot定时任务:同时使用定时任务和websocket报错
背景 项目使用了websocket,实现了消息的实时推送。后来项目需要一个定时任务,使用org.springframework.scheduling.annotation的EnableScheduling注解来实现,启动项目之后报错 Bean com.alibaba.cloud.sentinel.custom.SentinelAutoConfiguration of t…...
CSS3渐变及2D转换
CSS3渐变及2D转换 持续更新哦… 1、css3渐变 概念: CSS3渐变(gradient)可以让你在两个或多个指定的颜色之间显示平 稳的过渡。以前,你必须使用图像来实现这些效果,现在通过使用 CSS3的渐变(gradients)即可实现。此外,渐变效果的元素在放大…...
无涯教程-PHP - eregi()函数
eregi() - 语法 int eregi(string pattern, string string, [array regs]); eregi()函数在pattern指定的整个字符串中搜索string指定的字符串,。搜索不区分大小写。 Eregi()在检查字符串的有效性时特别有用。 可选的输入参数regs包含一个由正则表达式中的括号分组的所有匹配…...
Spring与Mybatis整合aop整合pageHelper分页插件
前言 Spring与MyBatis整合的意义在于提供了一种结合优势的方式,以便更好地开发和管理持久层(数据库访问)代码。 这里也是总结了几点主要意义 简化配置:Spring与MyBatis整合后,可以通过Spring的配置文件来管理和配置M…...
SSL/CA 证书及其相关证书文件(pem、crt、cer、key、csr)
数字证书是网络世界中的身份证,数字证书为实现双方安全通信提供了电子认证。数字证书中含有密钥对所有者的识别信息,通过验证识别信息的真伪实现对证书持有者身份的认证。数字证书可以在网络世界中为互不见面的用户建立安全可靠的信任关系,这…...
【JavaSE】内部类
文章目录 内部类概念局部内部类匿名内部类(重点重点!!! )成员内部类静态内部类 内部类概念 可以将一个类定义在另一个类或者一个方法的内部,前者称为内部类,后者称为外部类。内部类也是封装的一…...
Django(2)-编写你的第一个 Django 应用
本教程的目的是创建一个网络投票应用程序。 它将由两部分组成: 一个让人们查看和投票的公共站点。 一个让你能添加、修改和删除投票的管理站点。 创建应用 $ python manage.py startapp polls每一个应用是一个python包,一个项目可以包含多个应用。 …...
燃气管网监测系统,24小时守护燃气安全
随着社会的发展和人民生活水平的提高,燃气逐渐成为人们日常生活和工作中不可或缺的一部分。然而,近年来,屡屡发生的燃气爆炸问题,也让人们不禁对燃气的安全性产生了担忧。因此,建立一个高效、实时、准确的燃气管网监测…...
昌硕科技、世硕电子同步上线法大大电子合同
近日,世界500强企业和硕联合旗下上海昌硕科技有限公司(以下简称“昌硕科技”)、世硕电子(昆山)有限公司(以下简称“世硕电子”)的电子签项目正式上线。上线仪式在上海浦东和硕集团科研大楼举行&…...
(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)
题目:3442. 奇偶频次间的最大差值 I 思路 :哈希,时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况,哈希表这里用数组即可实现。 C版本: class Solution { public:int maxDifference(string s) {int a[26]…...
Ubuntu系统下交叉编译openssl
一、参考资料 OpenSSL&&libcurl库的交叉编译 - hesetone - 博客园 二、准备工作 1. 编译环境 宿主机:Ubuntu 20.04.6 LTSHost:ARM32位交叉编译器:arm-linux-gnueabihf-gcc-11.1.0 2. 设置交叉编译工具链 在交叉编译之前&#x…...
黑马Mybatis
Mybatis 表现层:页面展示 业务层:逻辑处理 持久层:持久数据化保存 在这里插入图片描述 Mybatis快速入门 
下载HBuilderX 访问官方网站:https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本: Windows版(推荐下载标准版) Windows系统安装步骤 运行安装程序: 双击下载的.exe安装文件 如果出现安全提示&…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一)
宇树机器人多姿态起立控制强化学习框架论文解析 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一) 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化…...
CMake 从 GitHub 下载第三方库并使用
有时我们希望直接使用 GitHub 上的开源库,而不想手动下载、编译和安装。 可以利用 CMake 提供的 FetchContent 模块来实现自动下载、构建和链接第三方库。 FetchContent 命令官方文档✅ 示例代码 我们将以 fmt 这个流行的格式化库为例,演示如何: 使用 FetchContent 从 GitH…...
全志A40i android7.1 调试信息打印串口由uart0改为uart3
一,概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本:2014.07; Kernel版本:Linux-3.10; 二,Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01),并让boo…...
