pytorch学习(7)——神经网络优化器torch.optim
1 optim 优化器
PyTorch神经网络优化器(optimizer)通过调整神经网络的参数(weight和bias)来最小化损失函数(Loss)。
学习链接:
https://pytorch.org/docs/stable/optim.html

1.1 优化器基类
使用时必须构造一个优化器对象,它将保存当前状态,并将根据计算的梯度(grad)更新参数。
调用优化器的step方法。
CLASS torch.optim.Optimizer(params, defaults)
- Optimizer - 优化器的优化算法。
- params (iterable) – torch的迭代器。张量s或dict s,指定应该优化什么张量。
- defaults – (dict): 包含优化选项默认值的字典(在参数组没有指定优化选项时使用)。每个Optimizer算法都有其独特的设置字典。
| 算法(Optimizer) | 说明 |
|---|---|
| Adadelta | 采用Adadelta算法。 |
| Adagrad | 采用Adagrad算法。 |
| Adam | 采用Adam算法。 |
| AdamW | 采用AdamW算法。 |
| SparseAdam | 采用适合稀疏张量的Adam算法的惰性版本。 |
| Adamax | 采用Adamax算法(Adam基于无穷范数的变种)。 |
| ASGD | 采用平均随机梯度下降。 |
| LBFGS | 采用L-BFGS算法,深受minFunc的启发。 |
| NAdam | 采用NAdam 算法。 |
| RAdam | 采用RAdam 算法。 |
| RMSprop | 采用RMSprop 算法。 |
| Rprop | 采用有弹性的反向传播算法。 |
| SGD | 采用随机梯度下降算法。 |
1.1.1 SGD 随机梯度下降算法
CLASS torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False)
- params (iterable) – iterable参数优化或字典定义参数组。
- lr (float) – 学习率,需要用户输入。
- momentum (float, optional) – 动量系数(默认值为0)。
- weight_decay (float, optional) – 权重衰减(L2惩罚) (默认值为0)
。 - dampening (float, optional) – 动量阻尼(默认值为0)。
- nesterov (bool, optional) – 使能Nesterov动量(默认值为
False)。
【Nesterov动量(Nesterov Momentum)是一种基于动量法的优化算法,用于加速神经网络的训练过程。它在随机梯度下降(SGD)的基础上进行改进,通过考虑参数更新前的动量信息来调整参数更新的方向。】 - maximize (bool, optional) – 根据目标最大化参数,而不是最小化参数(默认值为
False)。 - foreach (bool, optional) – 是否使用foreach优化器的实现。如果用户未指定(foreach为
None),我们将尝试在CUDA上的for循环实现上使用foreach,因为CUDA通常性能更高(默认值:None)。 - differentiable (bool, optional) – 是否在训练中的优化器步骤中发生autograd。否则,step()函数在torch.no_grad()上下文中运行。设置为
True会影响性能,所以如果你不打算通过这个实例运行autograd,请保留False(默认值为False)。
学习速率(lr)的取值,如果太大,则模型很不稳定;如果太小,学习速度非常缓慢。因此一般先设置较大的学习速率,然后降低学习速率。
python代码如下:
import torchvision
import torch
from torch import nn, optim
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential,CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root="G:\\Anaconda\\pycharm_pytorch\\learning_project\\dataset_CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=False)dataloader = DataLoader(dataset, batch_size=1)class MYNN(nn.Module):def __init__(self):super(MYNN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 64, 5, padding=2, stride=1),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = CrossEntropyLoss()
mynn = MYNN()
opitm = optim.SGD(mynn.parameters(), lr=0.01) # 优化器for data in dataloader:imgs, targets = dataoutputs = mynn(imgs)result_loss = loss(outputs, targets)# print(outputs) # 神经网络输出# print(targets) # 目标# print(result_loss) # 损失函数-交叉熵计算结果opitm.zero_grad() # 梯度清零,设置断点result_loss.backward() # 反向传播,求出每个节点的梯度,设置断点opitm.step() # 对神经网络模型的参数进行调优,设置断点
设置断点,进入程序Debug:
(1)不断运行程序,能够观察到卷积层0的bias梯度变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> bias-> grad

(2)能观察到卷积层0的weight梯度变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> weight -> grad

(3)能观察到bias的变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> bias-> data

(4)能观察到weight的变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> weight-> data

(5)结论:运行opitm.zero_grad()后,清空weight和bias的梯度grad;运行result_loss.backward()后,计算得到新的weight和bias的梯度grad;运行opitm.step()后,调整weight和bias的值。
1.1.2 优化器多次循环
修改以上python代码,增加多次循环,观察总体损失值改变。
import torchvision
import torch
from torch import nn, optim
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential,CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root="G:\\Anaconda\\pycharm_pytorch\\learning_project\\dataset_CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=False)dataloader = DataLoader(dataset, batch_size=1)class MYNN(nn.Module):def __init__(self):super(MYNN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 64, 5, padding=2, stride=1),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = CrossEntropyLoss()
mynn = MYNN()
opitm = optim.SGD(mynn.parameters(), lr=0.01) # 优化器for epoch in range(20):running_loss = 0.0for data in dataloader:imgs, targets = dataoutputs = mynn(imgs)result_loss = loss(outputs, targets)# print(outputs) # 神经网络输出# print(targets) # 目标# print(result_loss) # 损失函数-交叉熵计算结果opitm.zero_grad() # 梯度清零result_loss.backward() # 反向传播,求出每个节点的梯度opitm.step() # 对神经网络模型的参数进行调优running_loss = running_loss + result_loss#.dataprint(running_loss)
运行结果:
tensor(18746.2012, grad_fn=<AddBackward0>)
tensor(16136.0107, grad_fn=<AddBackward0>)
tensor(15499.3203, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
可以发现running_loss在一开始不断降低,但是以下的nan暂时不知道是什么原因。
相关文章:
pytorch学习(7)——神经网络优化器torch.optim
1 optim 优化器 PyTorch神经网络优化器(optimizer)通过调整神经网络的参数(weight和bias)来最小化损失函数(Loss)。 学习链接: https://pytorch.org/docs/stable/optim.html 1.1 优化器基类 使…...
leetcode做题笔记101. 对称二叉树
给你一个二叉树的根节点 root , 检查它是否轴对称。 思路一:递归 bool isSymmetric(struct TreeNode* root){if (root NULL) return true;return fun(root->left, root->right); }int fun(struct TreeNode* l_root, struct TreeNode* r_root) {…...
边缘计算相关概念--学习笔记
一.边缘计算概念 边缘计算将数据的处理,应用程序的运行甚至一些功能服务的实现,由网络中心下放到网络边缘的节点上,在网络边缘侧的智能网关上就近采集并且处理数据,不需要将大量未处理的数据上传到远程的大数据平台。边缘计算理论…...
flutter windows编译错误 flutter_assemble.vcxproj
flutter 编译windows是出现错误。 [ 44 ms] d:\Program Files\Microsoft Visual Studio\2022\Community\MSBuild\Microsoft\VC\v170\Microsoft.CppCommon.targets(248,5): error MSB8066: ��E:\work\kkview_kuaichuan\kkview_kuaichuan\build\windows\C…...
通过运行中的容器生成 Docker Compose 配置文件
背景 笔者之前有一次不小心删除了原始的 docker-compose.yml 文件,不过正在运行的 Docker 容器还在,找了许久,发现一个方法可以从这些容器中生成一个等效的 Docker Compose 配置文件。本文将介绍使用 autocompose 工具从正在运行的容器中反向…...
rancher界面无法登陆问题解决,登录超时;
1.找到rancher主机,查看日志 docker ps | grep rancher # rancher 容器 名称 jolly_ptolemy docker logs -f jolly_ptolemy 日志提示, java.sql.SQLException: Got error 28 from storage engine,磁盘满了 2.磁盘管理 df -h #查看磁盘使…...
Django(6)-django项目自动化测试
Django 应用的测试应该写在应用的 tests.py 文件里。测试系统会自动的在所有以 tests 开头的文件里寻找并执行测试代码。 我们的 polls 应用现在有一个小 bug 需要被修复:我们的要求是如果 Question 是在一天之内发布的, Question.was_published_recentl…...
【AUTOSAR】【CAN通信】CanNm
目录 一、概述 二、说明 三、功能说明 3.1 协调算法 3.2 操作模式 3.2.1 网络模式...
拼多多淘宝大量缓存商品数据用什么格式提供比较好?
众所周知,淘宝拼多多是我国主流的电商平台,其上有大量的商品数据。很多商家会通过API来访问他们的商品数据,根据API的调用次数收费。第三方数据公司提供电商数据接口API,采集实时数据。但是,在他们的服务器上有大量的缓…...
【校招VIP】前端校招考点之页面转换算法
考点介绍: 在地址映射过程中,若在页面中发现所要访问的页面不在内存中,则产生缺页中断。当发生缺页中断时,如果操作系统内存中没有空闲页面,则操作系统必须在内存选择一个页面将其移出内存,以便为即将调入的…...
android 下载网络文件
工具类 import android.app.ProgressDialog; import android.content.Context; import android.os.AsyncTask; import android.os.Environment; import android.util.Log;import java.io.BufferedInputStream; import java.io.File; import java.io.FileOutputStream; import …...
springboot定时任务:同时使用定时任务和websocket报错
背景 项目使用了websocket,实现了消息的实时推送。后来项目需要一个定时任务,使用org.springframework.scheduling.annotation的EnableScheduling注解来实现,启动项目之后报错 Bean com.alibaba.cloud.sentinel.custom.SentinelAutoConfiguration of t…...
CSS3渐变及2D转换
CSS3渐变及2D转换 持续更新哦… 1、css3渐变 概念: CSS3渐变(gradient)可以让你在两个或多个指定的颜色之间显示平 稳的过渡。以前,你必须使用图像来实现这些效果,现在通过使用 CSS3的渐变(gradients)即可实现。此外,渐变效果的元素在放大…...
无涯教程-PHP - eregi()函数
eregi() - 语法 int eregi(string pattern, string string, [array regs]); eregi()函数在pattern指定的整个字符串中搜索string指定的字符串,。搜索不区分大小写。 Eregi()在检查字符串的有效性时特别有用。 可选的输入参数regs包含一个由正则表达式中的括号分组的所有匹配…...
Spring与Mybatis整合aop整合pageHelper分页插件
前言 Spring与MyBatis整合的意义在于提供了一种结合优势的方式,以便更好地开发和管理持久层(数据库访问)代码。 这里也是总结了几点主要意义 简化配置:Spring与MyBatis整合后,可以通过Spring的配置文件来管理和配置M…...
SSL/CA 证书及其相关证书文件(pem、crt、cer、key、csr)
数字证书是网络世界中的身份证,数字证书为实现双方安全通信提供了电子认证。数字证书中含有密钥对所有者的识别信息,通过验证识别信息的真伪实现对证书持有者身份的认证。数字证书可以在网络世界中为互不见面的用户建立安全可靠的信任关系,这…...
【JavaSE】内部类
文章目录 内部类概念局部内部类匿名内部类(重点重点!!! )成员内部类静态内部类 内部类概念 可以将一个类定义在另一个类或者一个方法的内部,前者称为内部类,后者称为外部类。内部类也是封装的一…...
Django(2)-编写你的第一个 Django 应用
本教程的目的是创建一个网络投票应用程序。 它将由两部分组成: 一个让人们查看和投票的公共站点。 一个让你能添加、修改和删除投票的管理站点。 创建应用 $ python manage.py startapp polls每一个应用是一个python包,一个项目可以包含多个应用。 …...
燃气管网监测系统,24小时守护燃气安全
随着社会的发展和人民生活水平的提高,燃气逐渐成为人们日常生活和工作中不可或缺的一部分。然而,近年来,屡屡发生的燃气爆炸问题,也让人们不禁对燃气的安全性产生了担忧。因此,建立一个高效、实时、准确的燃气管网监测…...
昌硕科技、世硕电子同步上线法大大电子合同
近日,世界500强企业和硕联合旗下上海昌硕科技有限公司(以下简称“昌硕科技”)、世硕电子(昆山)有限公司(以下简称“世硕电子”)的电子签项目正式上线。上线仪式在上海浦东和硕集团科研大楼举行&…...
HTML 语义化
目录 HTML 语义化HTML5 新特性HTML 语义化的好处语义化标签的使用场景最佳实践 HTML 语义化 HTML5 新特性 标准答案: 语义化标签: <header>:页头<nav>:导航<main>:主要内容<article>&#x…...
日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする
日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする 1、前言(1)情况说明(2)工程师的信仰2、知识点(1) にする1,接续:名词+にする2,接续:疑问词+にする3,(A)は(B)にする。(2)復習:(1)复习句子(2)ために & ように(3)そう(4)にする3、…...
Zustand 状态管理库:极简而强大的解决方案
Zustand 是一个轻量级、快速和可扩展的状态管理库,特别适合 React 应用。它以简洁的 API 和高效的性能解决了 Redux 等状态管理方案中的繁琐问题。 核心优势对比 基本使用指南 1. 创建 Store // store.js import create from zustandconst useStore create((set)…...
2025年能源电力系统与流体力学国际会议 (EPSFD 2025)
2025年能源电力系统与流体力学国际会议(EPSFD 2025)将于本年度在美丽的杭州盛大召开。作为全球能源、电力系统以及流体力学领域的顶级盛会,EPSFD 2025旨在为来自世界各地的科学家、工程师和研究人员提供一个展示最新研究成果、分享实践经验及…...
ssc377d修改flash分区大小
1、flash的分区默认分配16M、 / # df -h Filesystem Size Used Available Use% Mounted on /dev/root 1.9M 1.9M 0 100% / /dev/mtdblock4 3.0M...
基础测试工具使用经验
背景 vtune,perf, nsight system等基础测试工具,都是用过的,但是没有记录,都逐渐忘了。所以写这篇博客总结记录一下,只要以后发现新的用法,就记得来编辑补充一下 perf 比较基础的用法: 先改这…...
高危文件识别的常用算法:原理、应用与企业场景
高危文件识别的常用算法:原理、应用与企业场景 高危文件识别旨在检测可能导致安全威胁的文件,如包含恶意代码、敏感数据或欺诈内容的文档,在企业协同办公环境中(如Teams、Google Workspace)尤为重要。结合大模型技术&…...
sqlserver 根据指定字符 解析拼接字符串
DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...
Python如何给视频添加音频和字幕
在Python中,给视频添加音频和字幕可以使用电影文件处理库MoviePy和字幕处理库Subtitles。下面将详细介绍如何使用这些库来实现视频的音频和字幕添加,包括必要的代码示例和详细解释。 环境准备 在开始之前,需要安装以下Python库:…...
微信小程序云开发平台MySQL的连接方式
注:微信小程序云开发平台指的是腾讯云开发 先给结论:微信小程序云开发平台的MySQL,无法通过获取数据库连接信息的方式进行连接,连接只能通过云开发的SDK连接,具体要参考官方文档: 为什么? 因为…...
