学习pytorch15 优化器
优化器
- 官网
- 如何构造一个优化器
- 优化器的step方法
- code
- running log
- 出现下面问题如何做反向优化?
官网
https://pytorch.org/docs/stable/optim.html

提问:优化器是什么 要优化什么 优化能干什么 优化是为了解决什么问题
优化模型参数
如何构造一个优化器
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # momentum SGD优化算法用到的参数
optimizer = optim.Adam([var1, var2], lr=0.0001)
- 选择一个优化器算法,如上 SGD 或者 Adam
- 第一个参数 需要传入模型参数
- 第二个及后面的参数是优化器算法特定需要的,lr 学习率基本每个优化器算法都会用到
优化器的step方法
会利用模型的梯度,根据梯度每一轮更新参数
optimizer.zero_grad() # 必须做 把上一轮计算的梯度清零,否则模型会有问题
for input, target in dataset:optimizer.zero_grad() # 必须做 把上一轮计算的梯度清零,否则模型会有问题output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()
or 把模型梯度包装成方法再调用
for input, target in dataset:def closure():optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()return lossoptimizer.step(closure)
code
import torch
import torchvision
from torch import nn, optim
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWritertest_set = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(test_set, batch_size=1)class MySeq(nn.Module):def __init__(self):super(MySeq, self).__init__()self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 64, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x# 定义loss
loss = nn.CrossEntropyLoss()
# 搭建网络
myseq = MySeq()
print(myseq)
# 定义优化器
optmizer = optim.SGD(myseq.parameters(), lr=0.001, momentum=0.9)
for epoch in range(20):running_loss = 0.0for data in dataloader:imgs, targets = data# print(imgs.shape)output = myseq(imgs)optmizer.zero_grad() # 每轮训练将梯度初始化为0 上一次的梯度对本轮参数优化没有用result_loss = loss(output, targets)result_loss.backward() # 优化器需要每个参数的梯度, 所以要在backward() 之后执行optmizer.step() # 根据梯度对每个参数进行调优# print(result_loss)# print(result_loss.grad)# print("ok")running_loss += result_lossprint(running_loss)
running log
loss由小变大最后到nan的解决办法:
- 降低学习率
- 使用正则化技术
- 增加训练数据
- 检查网络架构和激活函数
出现下面问题如何做反向优化?
Files already downloaded and verified
MySeq((model1): Sequential((0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Flatten(start_dim=1, end_dim=-1)(7): Linear(in_features=1024, out_features=64, bias=True)(8): Linear(in_features=64, out_features=10, bias=True))
)
tensor(18622.4551, grad_fn=<AddBackward0>)
tensor(16121.4092, grad_fn=<AddBackward0>)
tensor(15442.6416, grad_fn=<AddBackward0>)
tensor(16387.4531, grad_fn=<AddBackward0>)
tensor(18351.6152, grad_fn=<AddBackward0>)
tensor(20915.9785, grad_fn=<AddBackward0>)
tensor(23081.5254, grad_fn=<AddBackward0>)
tensor(24841.8359, grad_fn=<AddBackward0>)
tensor(25401.1602, grad_fn=<AddBackward0>)
tensor(26187.4961, grad_fn=<AddBackward0>)
tensor(28283.8633, grad_fn=<AddBackward0>)
tensor(30156.9316, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
相关文章:
学习pytorch15 优化器
优化器 官网如何构造一个优化器优化器的step方法coderunning log出现下面问题如何做反向优化? 官网 https://pytorch.org/docs/stable/optim.html 提问:优化器是什么 要优化什么 优化能干什么 优化是为了解决什么问题 优化模型参数 如何构造一个优化器…...
[算法日志]图论刷题 沉岛思想的运用
[算法日志]图论刷题: 沉岛思想的运用 leetcode 695 岛屿最大面积 给你一个大小为 m x n 的二进制矩阵 grid . 岛屿 是由一些相邻的 1 (代表土地) 构成的组合, 这里的「相邻」要求两个 1 必须在 水平或者竖直的四个方向上 相邻. 你可以假设 grid 的四个边缘都被 0(…...
Web服务器的搭建
网站需求: 1.基于域名www.openlab.com可以访问网站内容为 welcome to openlab!!! 2.给该公司创建三个网站目录分别显示学生信息,教学资料和缴费网站,基于www.openlab.com/student 网站访问学生信息,www.openlab.com/data网站访问教…...
如何使用 GTX750 或 1050 显卡安装 CUDA11+
前言 由于兼容性问题,使得我们若想用较新版本的 PyTorch,通过 GPU 方式训练模型,也得更换较新版本得 CUDA 工具包。然而 CUDA 的版本又与电脑显卡的驱动程序版本关联,如果是低版本的显卡驱动程序安装 CUDA11 及以上肯定会失败。 比…...
跟着森老师学React Hooks(1)——使用Vite构建React项目
Vite是一款构建工具,对ts有很好的支持,最近也是在前端越来越流行。 以往的React项目的初始化方式大多是通过脚手架create-react-app(本质是webpack),其实比起Vite来构建,启动会慢一些。 所以这次跟着B站的一个教程,使用…...
强力解决使用node版本管理工具 NVM 出现的问题(找不到 node,或者找不到 npm)
强力解决使用node版本管理工具 NVM 出现的问题(找不到 node,或者找不到 npm) node与npm版本对应关系 nvm是好用的Nodejs版本管理工具, 通过它可以方便地在本地调换Node版本。 2020-05-28 Node当前长期稳定版12.17.0,…...
Docker指定容器使用内存
Docker指定容器使用内存 作者:铁乐与猫 如果是还没有生成的容器,你可以从指定镜像生成容器时特意加上 run -m 256m 或 --memory-swap512m来限制。 -m操作指定的是物理内存,还有虚拟交换分区默认也会生成同样的大小,而–memory-…...
做什么数据表格啊,要做就做数据可视化
是一堆数字更易懂,还是图表更易懂?很明显是图表,特别是数据可视化图表。数据可视化是一种将大量数据转化为视觉形式的过程,通过图形、图表、图像等方式呈现数据,以便更直观地理解和分析。 数据可视化更加生动、形象地…...
CSS特效003:太阳、地球、月球的旋转
GPT能够很好的应用到我们的代码开发中,能够提高开发速度。你可以利用其代码,做出一定的更改,然后实现效能。 css实战中,这种球体间的旋转,主要通过rotate()旋转函数来实现。实际上,蓝色的地球和黑色的月球…...
云计算的大模型之争,亚马逊云科技落后了?
文丨智能相对论 作者丨沈浪 “OpenAI使用了Azure的智能云服务”——在过去的半年,这几乎成为了微软智能云最好的广告词。 正所谓“水涨船高”,凭借OpenAI旗下的ChatGPT在全球范围内爆发,微软趁势拉了一波自家的云计算业务。2023年二季度&a…...
【form校验】3.0项目多层list嵌套
const { required, phoneOrMobile } CjmForm.rules; export default function detail() {const { query } getRouterInfo(location);const formRef useRef(null);const [crumbList, setCrumbList] useState([{url: "/wenling/Reviewer",name: "审核人员&quo…...
公共功能测试用例
1、UI测试 布局是否合理,输入框、按钮是否对齐 行列间距是否保持一致弹出窗口垂直居中对其界面的设计风格是否与UI的设计风格一致 系统是否使用统一风格的控件界面的文字是否简洁易懂,是否有错别字 兼容性测试:不同浏览器、版本、分辨率下&a…...
【电路笔记】-并联RLC电路分析
并联RLC电路分析 文章目录 并联RLC电路分析1、概述2、AC的行为3、替代配置3.1 带阻滤波器3.2 带通滤波器 4、总结 电子器件三个基本元件的串联行为已在我们之前的文章系列 RLC 电路分析中详细介绍。 在本文中,介绍了另一种称为并联 RLC 电路的关联。 在第一部分中&a…...
ros1 client
Client(客户端):发布海龟生成请求 [类似Publisher] Serve(服务端):海龟仿真器,接收请求 [类似于Subscriber] Service(服务):生成海龟的具体内容,其中服务类型…...
射频功率放大器应用中GaN HEMT的表面电势模型
标题:A surface-potential based model for GaN HEMTs in RF power amplifier applications 来源:IEEE IEDM 2010 本文中的任何第一人称都为论文的直译 摘要:我们提出了第一个基于表面电位的射频GaN HEMTs紧凑模型,并将我们的工…...
CSP(Common Spatial Patterns)——EEG特征提取方法详解
基于CSP的运动想象 EEG 特征提取和可视化参考前文:https://blog.csdn.net/qq_43811536/article/details/134273470?spm1001.2014.3001.5501 目录 1. CSP是什么?1.1 CSP的含义1.2 CSP算法1.3 CSP特征的特点 2. CSP特征在EEG信号分类任务中的应用2.1 任务…...
【Git】Git 学习笔记_操作本地仓库
1. 安装与初始化配置 1.1 安装 下载地址 在文件夹里右键点击 git bash here 即可打开命令行面板。 git -v // 查看版本1.2 配置 git config --global user.name "heo" git config --global user.email xxxgmail.com git config --global credential.helper stor…...
杂记(3):在Pytorch中如何操作将数据集分为训练集和测试集?
在Pytorch中如何操作将数据集分为训练集和测试集? 0. 前言1. 手动切分2. train_test_split方法3. Pytorch自带方法4. 总结 0. 前言 数据集需要分为训练集和测试集! 其中,训练集单纯用来训练,优化模型参数;测试集单纯用…...
【MySQL篇】数据库角色
前言 数据库角色是被命名的一组与数据库操作相关的权限,角色是权限的集合。因此,可以为一组具有相同权限的用户创建一个角色,使用角色来管理数据库权限可以简化授权的过程。 CREATE ROLE:创建一个角色 GRANT:给角色授…...
c++ 信奥赛编程 2050:【例5.20】字串包含
#include<iostream> #include<cstring> using namespace std; int main() {string str1,str2;int temp;cin>>str1>>str2;//判断长度 if(str1.size()<str2.size()){ swap(str1,str2); //交换内容 }str1str1str1; //AABCDAABCDAABCDAABCDif(str…...
Docker 离线安装指南
参考文章 1、确认操作系统类型及内核版本 Docker依赖于Linux内核的一些特性,不同版本的Docker对内核版本有不同要求。例如,Docker 17.06及之后的版本通常需要Linux内核3.10及以上版本,Docker17.09及更高版本对应Linux内核4.9.x及更高版本。…...
生成xcframework
打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式,可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...
【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15
缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下: struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...
聊聊 Pulsar:Producer 源码解析
一、前言 Apache Pulsar 是一个企业级的开源分布式消息传递平台,以其高性能、可扩展性和存储计算分离架构在消息队列和流处理领域独树一帜。在 Pulsar 的核心架构中,Producer(生产者) 是连接客户端应用与消息队列的第一步。生产者…...
江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命
在华东塑料包装行业面临限塑令深度调整的背景下,江苏艾立泰以一场跨国资源接力的创新实践,重新定义了绿色供应链的边界。 跨国回收网络:废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点,将海外废弃包装箱通过标准…...
苍穹外卖--缓存菜品
1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得,如果用户端访问量比较大,数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据,减少数据库查询操作。 缓存逻辑分析: ①每个分类下的菜品保持一份缓存数据…...
C# 类和继承(抽象类)
抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...
3403. 从盒子中找出字典序最大的字符串 I
3403. 从盒子中找出字典序最大的字符串 I 题目链接:3403. 从盒子中找出字典序最大的字符串 I 代码如下: class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...
在WSL2的Ubuntu镜像中安装Docker
Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包: for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...
分布式增量爬虫实现方案
之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面,避免重复抓取,以节省资源和时间。 在分布式环境下,增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路:将增量判…...
