当前位置: 首页 > news >正文

学习pytorch15 优化器

优化器

  • 官网
  • 如何构造一个优化器
  • 优化器的step方法
  • code
  • running log
    • 出现下面问题如何做反向优化?

官网

https://pytorch.org/docs/stable/optim.html

在这里插入图片描述
提问:优化器是什么 要优化什么 优化能干什么 优化是为了解决什么问题
优化模型参数

如何构造一个优化器

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # momentum SGD优化算法用到的参数
optimizer = optim.Adam([var1, var2], lr=0.0001)
  1. 选择一个优化器算法,如上 SGD 或者 Adam
  2. 第一个参数 需要传入模型参数
  3. 第二个及后面的参数是优化器算法特定需要的,lr 学习率基本每个优化器算法都会用到

优化器的step方法

会利用模型的梯度,根据梯度每一轮更新参数
optimizer.zero_grad() # 必须做 把上一轮计算的梯度清零,否则模型会有问题

for input, target in dataset:optimizer.zero_grad()  # 必须做 把上一轮计算的梯度清零,否则模型会有问题output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()

or 把模型梯度包装成方法再调用

for input, target in dataset:def closure():optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()return lossoptimizer.step(closure)

code

import torch
import torchvision
from torch import nn, optim
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWritertest_set = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(test_set, batch_size=1)class MySeq(nn.Module):def __init__(self):super(MySeq, self).__init__()self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 64, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x# 定义loss
loss = nn.CrossEntropyLoss()
# 搭建网络
myseq = MySeq()
print(myseq)
# 定义优化器
optmizer = optim.SGD(myseq.parameters(), lr=0.001, momentum=0.9)
for epoch in range(20):running_loss = 0.0for data in dataloader:imgs, targets = data# print(imgs.shape)output = myseq(imgs)optmizer.zero_grad()  # 每轮训练将梯度初始化为0  上一次的梯度对本轮参数优化没有用result_loss = loss(output, targets)result_loss.backward()  # 优化器需要每个参数的梯度, 所以要在backward() 之后执行optmizer.step()  # 根据梯度对每个参数进行调优# print(result_loss)# print(result_loss.grad)# print("ok")running_loss += result_lossprint(running_loss)

running log

loss由小变大最后到nan的解决办法:

  1. 降低学习率
  2. 使用正则化技术
  3. 增加训练数据
  4. 检查网络架构和激活函数

出现下面问题如何做反向优化?

Files already downloaded and verified
MySeq((model1): Sequential((0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Flatten(start_dim=1, end_dim=-1)(7): Linear(in_features=1024, out_features=64, bias=True)(8): Linear(in_features=64, out_features=10, bias=True))
)
tensor(18622.4551, grad_fn=<AddBackward0>)
tensor(16121.4092, grad_fn=<AddBackward0>)
tensor(15442.6416, grad_fn=<AddBackward0>)
tensor(16387.4531, grad_fn=<AddBackward0>)
tensor(18351.6152, grad_fn=<AddBackward0>)
tensor(20915.9785, grad_fn=<AddBackward0>)
tensor(23081.5254, grad_fn=<AddBackward0>)
tensor(24841.8359, grad_fn=<AddBackward0>)
tensor(25401.1602, grad_fn=<AddBackward0>)
tensor(26187.4961, grad_fn=<AddBackward0>)
tensor(28283.8633, grad_fn=<AddBackward0>)
tensor(30156.9316, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)

相关文章:

学习pytorch15 优化器

优化器 官网如何构造一个优化器优化器的step方法coderunning log出现下面问题如何做反向优化&#xff1f; 官网 https://pytorch.org/docs/stable/optim.html 提问&#xff1a;优化器是什么 要优化什么 优化能干什么 优化是为了解决什么问题 优化模型参数 如何构造一个优化器…...

[算法日志]图论刷题 沉岛思想的运用

[算法日志]图论刷题: 沉岛思想的运用 leetcode 695 岛屿最大面积 给你一个大小为 m x n 的二进制矩阵 grid . 岛屿 是由一些相邻的 1 (代表土地) 构成的组合, 这里的「相邻」要求两个 1 必须在 水平或者竖直的四个方向上 相邻. 你可以假设 grid 的四个边缘都被 0&#xff08…...

Web服务器的搭建

网站需求&#xff1a; 1.基于域名www.openlab.com可以访问网站内容为 welcome to openlab!!! 2.给该公司创建三个网站目录分别显示学生信息&#xff0c;教学资料和缴费网站&#xff0c;基于www.openlab.com/student 网站访问学生信息&#xff0c;www.openlab.com/data网站访问教…...

如何使用 GTX750 或 1050 显卡安装 CUDA11+

前言 由于兼容性问题&#xff0c;使得我们若想用较新版本的 PyTorch&#xff0c;通过 GPU 方式训练模型&#xff0c;也得更换较新版本得 CUDA 工具包。然而 CUDA 的版本又与电脑显卡的驱动程序版本关联&#xff0c;如果是低版本的显卡驱动程序安装 CUDA11 及以上肯定会失败。 比…...

跟着森老师学React Hooks(1)——使用Vite构建React项目

Vite是一款构建工具&#xff0c;对ts有很好的支持&#xff0c;最近也是在前端越来越流行。 以往的React项目的初始化方式大多是通过脚手架create-react-app(本质是webpack)&#xff0c;其实比起Vite来构建&#xff0c;启动会慢一些。 所以这次跟着B站的一个教程&#xff0c;使用…...

强力解决使用node版本管理工具 NVM 出现的问题(找不到 node,或者找不到 npm)

强力解决使用node版本管理工具 NVM 出现的问题&#xff08;找不到 node&#xff0c;或者找不到 npm&#xff09; node与npm版本对应关系 nvm是好用的Nodejs版本管理工具&#xff0c; 通过它可以方便地在本地调换Node版本。 2020-05-28 Node当前长期稳定版12.17.0&#xff0c;…...

Docker指定容器使用内存

Docker指定容器使用内存 作者&#xff1a;铁乐与猫 如果是还没有生成的容器&#xff0c;你可以从指定镜像生成容器时特意加上 run -m 256m 或 --memory-swap512m来限制。 -m操作指定的是物理内存&#xff0c;还有虚拟交换分区默认也会生成同样的大小&#xff0c;而–memory-…...

做什么数据表格啊,要做就做数据可视化

是一堆数字更易懂&#xff0c;还是图表更易懂&#xff1f;很明显是图表&#xff0c;特别是数据可视化图表。数据可视化是一种将大量数据转化为视觉形式的过程&#xff0c;通过图形、图表、图像等方式呈现数据&#xff0c;以便更直观地理解和分析。 数据可视化更加生动、形象地…...

CSS特效003:太阳、地球、月球的旋转

GPT能够很好的应用到我们的代码开发中&#xff0c;能够提高开发速度。你可以利用其代码&#xff0c;做出一定的更改&#xff0c;然后实现效能。 css实战中&#xff0c;这种球体间的旋转&#xff0c;主要通过rotate()旋转函数来实现。实际上&#xff0c;蓝色的地球和黑色的月球…...

云计算的大模型之争,亚马逊云科技落后了?

文丨智能相对论 作者丨沈浪 “OpenAI使用了Azure的智能云服务”——在过去的半年&#xff0c;这几乎成为了微软智能云最好的广告词。 正所谓“水涨船高”&#xff0c;凭借OpenAI旗下的ChatGPT在全球范围内爆发&#xff0c;微软趁势拉了一波自家的云计算业务。2023年二季度&a…...

【form校验】3.0项目多层list嵌套

const { required, phoneOrMobile } CjmForm.rules; export default function detail() {const { query } getRouterInfo(location);const formRef useRef(null);const [crumbList, setCrumbList] useState([{url: "/wenling/Reviewer",name: "审核人员&quo…...

公共功能测试用例

1、UI测试 布局是否合理&#xff0c;输入框、按钮是否对齐 行列间距是否保持一致弹出窗口垂直居中对其界面的设计风格是否与UI的设计风格一致 系统是否使用统一风格的控件界面的文字是否简洁易懂&#xff0c;是否有错别字 兼容性测试&#xff1a;不同浏览器、版本、分辨率下&a…...

【电路笔记】-并联RLC电路分析

并联RLC电路分析 文章目录 并联RLC电路分析1、概述2、AC的行为3、替代配置3.1 带阻滤波器3.2 带通滤波器 4、总结 电子器件三个基本元件的串联行为已在我们之前的文章系列 RLC 电路分析中详细介绍。 在本文中&#xff0c;介绍了另一种称为并联 RLC 电路的关联。 在第一部分中&a…...

ros1 client

Client&#xff08;客户端&#xff09;&#xff1a;发布海龟生成请求 [类似Publisher] Serve&#xff08;服务端&#xff09;&#xff1a;海龟仿真器,接收请求 [类似于Subscriber] Service&#xff08;服务&#xff09;&#xff1a;生成海龟的具体内容&#xff0c;其中服务类型…...

射频功率放大器应用中GaN HEMT的表面电势模型

标题&#xff1a;A surface-potential based model for GaN HEMTs in RF power amplifier applications 来源&#xff1a;IEEE IEDM 2010 本文中的任何第一人称都为论文的直译 摘要&#xff1a;我们提出了第一个基于表面电位的射频GaN HEMTs紧凑模型&#xff0c;并将我们的工…...

CSP(Common Spatial Patterns)——EEG特征提取方法详解

基于CSP的运动想象 EEG 特征提取和可视化参考前文&#xff1a;https://blog.csdn.net/qq_43811536/article/details/134273470?spm1001.2014.3001.5501 目录 1. CSP是什么&#xff1f;1.1 CSP的含义1.2 CSP算法1.3 CSP特征的特点 2. CSP特征在EEG信号分类任务中的应用2.1 任务…...

【Git】Git 学习笔记_操作本地仓库

1. 安装与初始化配置 1.1 安装 下载地址 在文件夹里右键点击 git bash here 即可打开命令行面板。 git -v // 查看版本1.2 配置 git config --global user.name "heo" git config --global user.email xxxgmail.com git config --global credential.helper stor…...

杂记(3):在Pytorch中如何操作将数据集分为训练集和测试集?

在Pytorch中如何操作将数据集分为训练集和测试集&#xff1f; 0. 前言1. 手动切分2. train_test_split方法3. Pytorch自带方法4. 总结 0. 前言 数据集需要分为训练集和测试集&#xff01; 其中&#xff0c;训练集单纯用来训练&#xff0c;优化模型参数&#xff1b;测试集单纯用…...

【MySQL篇】数据库角色

前言 数据库角色是被命名的一组与数据库操作相关的权限&#xff0c;角色是权限的集合。因此&#xff0c;可以为一组具有相同权限的用户创建一个角色&#xff0c;使用角色来管理数据库权限可以简化授权的过程。 CREATE ROLE&#xff1a;创建一个角色 GRANT&#xff1a;给角色授…...

c++ 信奥赛编程 2050:【例5.20】字串包含

#include<iostream> #include<cstring> using namespace std; int main() {string str1,str2;int temp;cin>>str1>>str2;//判断长度 if(str1.size()<str2.size()){ swap(str1,str2); //交换内容 }str1str1str1; //AABCDAABCDAABCDAABCDif(str…...

【入坑系列】TiDB 强制索引在不同库下不生效问题

文章目录 背景SQL 优化情况线上SQL运行情况分析怀疑1:执行计划绑定问题?尝试:SHOW WARNINGS 查看警告探索 TiDB 的 USE_INDEX 写法Hint 不生效问题排查解决参考背景 项目中使用 TiDB 数据库,并对 SQL 进行优化了,添加了强制索引。 UAT 环境已经生效,但 PROD 环境强制索…...

渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止

<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet&#xff1a; https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...

YSYX学习记录(八)

C语言&#xff0c;练习0&#xff1a; 先创建一个文件夹&#xff0c;我用的是物理机&#xff1a; 安装build-essential 练习1&#xff1a; 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件&#xff0c;随机修改或删除一部分&#xff0c;之后…...

VTK如何让部分单位不可见

最近遇到一个需求&#xff0c;需要让一个vtkDataSet中的部分单元不可见&#xff0c;查阅了一些资料大概有以下几种方式 1.通过颜色映射表来进行&#xff0c;是最正规的做法 vtkNew<vtkLookupTable> lut; //值为0不显示&#xff0c;主要是最后一个参数&#xff0c;透明度…...

DBAPI如何优雅的获取单条数据

API如何优雅的获取单条数据 案例一 对于查询类API&#xff0c;查询的是单条数据&#xff0c;比如根据主键ID查询用户信息&#xff0c;sql如下&#xff1a; select id, name, age from user where id #{id}API默认返回的数据格式是多条的&#xff0c;如下&#xff1a; {&qu…...

Spring AI与Spring Modulith核心技术解析

Spring AI核心架构解析 Spring AI&#xff08;https://spring.io/projects/spring-ai&#xff09;作为Spring生态中的AI集成框架&#xff0c;其核心设计理念是通过模块化架构降低AI应用的开发复杂度。与Python生态中的LangChain/LlamaIndex等工具类似&#xff0c;但特别为多语…...

佰力博科技与您探讨热释电测量的几种方法

热释电的测量主要涉及热释电系数的测定&#xff0c;这是表征热释电材料性能的重要参数。热释电系数的测量方法主要包括静态法、动态法和积分电荷法。其中&#xff0c;积分电荷法最为常用&#xff0c;其原理是通过测量在电容器上积累的热释电电荷&#xff0c;从而确定热释电系数…...

代码随想录刷题day30

1、零钱兑换II 给你一个整数数组 coins 表示不同面额的硬币&#xff0c;另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额&#xff0c;返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 32 位带…...

动态 Web 开发技术入门篇

一、HTTP 协议核心 1.1 HTTP 基础 协议全称 &#xff1a;HyperText Transfer Protocol&#xff08;超文本传输协议&#xff09; 默认端口 &#xff1a;HTTP 使用 80 端口&#xff0c;HTTPS 使用 443 端口。 请求方法 &#xff1a; GET &#xff1a;用于获取资源&#xff0c;…...

GitFlow 工作模式(详解)

今天再学项目的过程中遇到使用gitflow模式管理代码&#xff0c;因此进行学习并且发布关于gitflow的一些思考 Git与GitFlow模式 我们在写代码的时候通常会进行网上保存&#xff0c;无论是github还是gittee&#xff0c;都是一种基于git去保存代码的形式&#xff0c;这样保存代码…...