当前位置: 首页 > news >正文

pytorch学习(十)优化函数

优化函数主要有,SGD, Adam,RMSProp这三种,并且有lr学习率,momentum动量,betas等参数需要设置。

通过这篇文章,可以学到pytorch中的优化函数的使用。

1.代码

代码参考《python深度学习-基于pytorch》,改了一下网络结构,其他没变化。

import torch
import torch.utils.data as Data
import torch.nn.functional as Func
import matplotlib.pyplot as pltLR =0.01
BATCH_SIZE = 20
EPOCH = 12#生成数据
#将一维变成二维数据
x = torch.unsqueeze(torch.linspace(-1,1,1000),dim=1)
y = x.pow(2) + 0.1 * torch.normal(torch.zeros(*x.size()))A = x.size()
B = x.size()torch_dataset = Data.TensorDataset(x,y)
data_loader = Data.DataLoader(dataset=torch_dataset,batch_size= BATCH_SIZE,shuffle=False)class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.hidden1 = torch.nn.Linear(1,20)self.hidden2 = torch.nn.Linear(20, 40)self.predict = torch.nn.Linear(40,1)def forward(self,x):x = Func.relu(self.hidden1(x))x = Func.relu(self.hidden2(x))x = self.predict(x)return xnet_SGD = Net()
net_Momentum = Net()
net_PMSProp = Net()
NET_Adam = Net()nets = {net_SGD,net_Momentum,net_PMSProp,NET_Adam }opt_SGD = torch.optim.SGD(net_SGD.parameters(),lr=LR)
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(),momentum=0.3,lr=LR)
opt_PMSProp = torch.optim.RMSprop(net_PMSProp.parameters())
opt_Adam = torch.optim.Adam(NET_Adam.parameters(),lr=LR,betas=(0.9,0.99))optimizers = {opt_SGD,opt_Momentum, opt_PMSProp, opt_Adam}loss_func = torch.nn.MSELoss()loss_his =[[],[],[],[]]for epoch in range((EPOCH)):for step,(batch_x,batch_y) in enumerate(data_loader):for net, opt, l_his, in zip(nets, optimizers,loss_his):output = net(batch_x)loss = loss_func(output,batch_y)opt.zero_grad()loss.backward()opt.step()l_his.append(loss.data.numpy())
labels = ['SGD','SGD_Momentum','RMSProp','Adam']#可视化
for i, l_his in enumerate(loss_his):plt.plot(l_his,label = labels[i])
plt.legend(loc='best')
plt.xlabel('steps')
plt.ylabel('loss')
# plt.ylim((0,0.8))
plt.show()

2.结果

通过测试,发现每一次的结果都不一样,每一次结果的显示图也不一样。因为shuffle=True。

为shuffle=True时候显示的其中一个结果为:

当shuffle为False时,发现也不稳定,其中一张结果显示为:

3.大家copy代码后,可以调一调batch_size,lr,momentum,betas等参数。

其中lr动态修改学习率的代码为:

print(opt_SGD.param_groups)
opt_SGD.param_groups[0]['lr']*=0.1
opt_SGD.param_groups[0]['momentum']

相关文章:

pytorch学习(十)优化函数

优化函数主要有,SGD, Adam,RMSProp这三种,并且有lr学习率,momentum动量,betas等参数需要设置。 通过这篇文章,可以学到pytorch中的优化函数的使用。 1.代码 代码参考《python深度学习-基于pytorch》&…...

Ubuntu22.04:安装Samba

1.安装Samba服务 $ sudo apt install samba samba-common 2.创建共享目录 $ mkdir /home/xxx/samba $ chmod 777 /home/xxx/samba 3.将用户加入到Samba服务中 $ sudo smbpasswd -a xxx 设置用户xxx访问Samba的密码 4.配置Samba服务 $ sudo vi /etc/samba/smb.conf 在最后加入 …...

Powershell 使用介绍

0 Preface/Foreword 0.1 参考文档 Starting Windows PowerShell - PowerShell | Microsoft Learn 1 Powershell 介绍 2 命令介绍 2.1 新建文件夹 New-Item -Path C:\GitLab-Runner -ItemType Directory 2.2 切换路径 cd C:\GitLab-Runner 2.3 下载文件 Invoke-WebRequ…...

【Langchain大语言模型开发教程】记忆

🔗 LangChain for LLM Application Development - DeepLearning.AI 学习目标 1、Langchain的历史记忆 ConversationBufferMemory 2、基于窗口限制的临时记忆 ConversationBufferWindowMemory 3、基于Token数量的临时记忆 ConversationTokenBufferMemory 4、基于历史…...

最新Qt6的下载与成功安装详细介绍

引言 Qt6 是一款强大的跨平台应用程序开发框架,支持多种编程语言,最常用的是C。Qt6带来了许多改进和新功能,包括对C17的支持、增强的QML和UI技术、新的图形架构,以及构建系统方面的革新。本文将指导你如何在Windows平台上下载和安…...

LeetCode 热题 HOT 100 (001/100)【宇宙最简单版】

【链表】 No. 0160 相交链表 【简单】👉力扣对应题目指路 希望对你有帮助呀!!💜💜 如有更好理解的思路,欢迎大家留言补充 ~ 一起加油叭 💦 欢迎关注、订阅专栏 【力扣详解】谢谢你的支持&#x…...

Ubantu 使用 docker 配置 + 远程部署 + 远程开发

大家好我是苏麟 , Ubantu 一些配置 . 视频 : 服务器很贵?搞台虚拟机玩玩!保姆级 Linux 远程开发教程_哔哩哔哩_bilibili Docker安装及配置 安装命令 : sudo apt install docker.io 查看版本号 : docker -v 查看虚拟机地址命令 : ifconfig 虚拟机地址 或…...

应用层自定义协议与序列化

个人主页:Lei宝啊 愿所有美好如期而遇 协议 简单来说,就是通信双方约定好的结构化的数据。 序列化与反序列化 我们通过一个问题引入这个概念,假如我们要实现一个网络版的计算器,那么现在有两种方案,第一种&#x…...

Python学习笔记—100页Opencv详细讲解教程

目录 1 创建和显示窗口... - 4 - 2 加载显示图片... - 6 - 3 保存图片... - 7 - 4 视频采集... - 8 - 5视频录制... - 11 - 6 控制鼠标... - 12 - 7 TrackBar 控件... - 14 - 8.RGB和BGR颜色空间... - 16 - 9.HSV和HSL和YUV.. - 17 - 10 颜色空间的转化... - 18 - …...

C语言·分支和循环语句(超详细系列·全面总结)

前言:Hello大家好😘,我是心跳sy,为了更好地形成一个学习c语言的体系,最近将会更新关于c语言语法基础的知识,今天更新一下分支循环语句的知识点,我们一起来看看吧~ 目录 一、什么是语句&#xf…...

Gateway源码分析:路由Route、断言Predicate、Filter

文章目录 源码总流程图说明GateWayAutoConfigurationDispatcherHandlergetHandler()handleRequestWith()RouteToRequestUrlFilterReactiveLoadBalancerClientFilterNettyRoutingFilter 补充知识适配器模式 详细流程图 源码总流程图 在线总流程图 说明 Gateway的版本使用的是…...

ARM体系结构和接口技术(十)按键中断实验①

一、按键中断实验 (一)分析按键电路图 (二)芯片手册 二、按键中断实验分析 注:NVIC----Cortx-M核GIC----Cortx-A核 (一)查看所有外设的总线以及寄存器基地址 注:GIC的总线是A7核的…...

PostgreSQL使用(二)——插入、更新、删除数据

说明:本文介绍PostgreSQL的DML语言; 插入数据 -- 1.全字段插入,字段名可以省略 insert into tb_student values (1, 张三, 1990-01-01, 88.88);-- 2.部分字段插入,字段名必须写全 insert into tb_student (id, name) values (2,…...

有关css的题目

css样式来源有哪些&#xff1f; 内联样式&#xff1a; <a style"color: red"> </a> 内部样式&#xff1a;<style></style> 外部样式&#xff1a;写在独立的 .css文件中的 浏览器的默认样式 display有哪些属性 none - 不展示 block - 块类型…...

【开源库】libodb库编译及使用

前言 本文介绍windows平台下libodb库的编译及使用。 文末提供libodb-2.4.0编译好的msvc2019_64版本&#xff0c;可直接跳转自取 ODB库学习相关 【开源库学习】libodb库学习&#xff08;一&#xff09; 【开源库学习】libodb库学习&#xff08;二&#xff09; 【开源库学习】…...

电力需求预测挑战赛笔记 Task3 #Datawhale AI 夏令营

上文&#xff1a; 电力需求预测挑战赛笔记 Task2 #Datawhale AI 夏令营-CSDN博客文章浏览阅读80次。【代码】电力需求预测挑战赛笔记 Task2。https://blog.csdn.net/qq_23311271/article/details/140360632 前面我们介绍了如何使用经验模型以及常见的lightgbm决策树模型来解决…...

Promise 详解(原理篇)

目录 什么是 Promise 实现一个 Promise Promise 的声明 解决基本状态 添加 then 方法 解决异步实现 解决链式调用 完成 resolvePromise 函数 解决其他问题 添加 catch 方法 添加 finally 方法 添加 resolve、reject、race、all 等方法 如何验证我们的 Promise 是否…...

动态内存经典笔试题分析

目录 1.题目一 2.题目二 3.题目三 4.题目四 1.题目一 #include<stdlib.h> #include<stdio.h> #include<string.h> void GetMemory(char* p) {p (char*)malloc(100); } void Test(void) {char* str NULL;GetMemory(str);strcpy(str, "hello world…...

JS设计模式(一)单例模式

注释很详细&#xff0c;直接上代码 本文建立在已有JS面向对象基础的前提下&#xff0c;若无&#xff0c;请移步以下博客先行了解 JS面向对象&#xff08;一&#xff09;类与对象写法 特点和用途&#xff1a; 全局访问点&#xff1a;通过单例模式可以在整个应用程序中访问同一个…...

uniapp动态计算并设置元素高度

<template><view><scroll-view id"sv-box" :scroll-y"true" :style"{height:navHeightpx}"></scroll-view><view id"btn-box"><button>取消</button><button>确认</button><…...

铭豹扩展坞 USB转网口 突然无法识别解决方法

当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…...

全球首个30米分辨率湿地数据集(2000—2022)

数据简介 今天我们分享的数据是全球30米分辨率湿地数据集&#xff0c;包含8种湿地亚类&#xff0c;该数据以0.5X0.5的瓦片存储&#xff0c;我们整理了所有属于中国的瓦片名称与其对应省份&#xff0c;方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...

转转集团旗下首家二手多品类循环仓店“超级转转”开业

6月9日&#xff0c;国内领先的循环经济企业转转集团旗下首家二手多品类循环仓店“超级转转”正式开业。 转转集团创始人兼CEO黄炜、转转循环时尚发起人朱珠、转转集团COO兼红布林CEO胡伟琨、王府井集团副总裁祝捷等出席了开业剪彩仪式。 据「TMT星球」了解&#xff0c;“超级…...

Nginx server_name 配置说明

Nginx 是一个高性能的反向代理和负载均衡服务器&#xff0c;其核心配置之一是 server 块中的 server_name 指令。server_name 决定了 Nginx 如何根据客户端请求的 Host 头匹配对应的虚拟主机&#xff08;Virtual Host&#xff09;。 1. 简介 Nginx 使用 server_name 指令来确定…...

USB Over IP专用硬件的5个特点

USB over IP技术通过将USB协议数据封装在标准TCP/IP网络数据包中&#xff0c;从根本上改变了USB连接。这允许客户端通过局域网或广域网远程访问和控制物理连接到服务器的USB设备&#xff08;如专用硬件设备&#xff09;&#xff0c;从而消除了直接物理连接的需要。USB over IP的…...

Pinocchio 库详解及其在足式机器人上的应用

Pinocchio 库详解及其在足式机器人上的应用 Pinocchio (Pinocchio is not only a nose) 是一个开源的 C 库&#xff0c;专门用于快速计算机器人模型的正向运动学、逆向运动学、雅可比矩阵、动力学和动力学导数。它主要关注效率和准确性&#xff0c;并提供了一个通用的框架&…...

Linux C语言网络编程详细入门教程:如何一步步实现TCP服务端与客户端通信

文章目录 Linux C语言网络编程详细入门教程&#xff1a;如何一步步实现TCP服务端与客户端通信前言一、网络通信基础概念二、服务端与客户端的完整流程图解三、每一步的详细讲解和代码示例1. 创建Socket&#xff08;服务端和客户端都要&#xff09;2. 绑定本地地址和端口&#x…...

使用Spring AI和MCP协议构建图片搜索服务

目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式&#xff08;本地调用&#xff09; SSE模式&#xff08;远程调用&#xff09; 4. 注册工具提…...

scikit-learn机器学习

# 同时添加如下代码, 这样每次环境(kernel)启动的时候只要运行下方代码即可: # Also add the following code, # so that every time the environment (kernel) starts, # just run the following code: import sys sys.path.append(/home/aistudio/external-libraries)机…...

Qemu arm操作系统开发环境

使用qemu虚拟arm硬件比较合适。 步骤如下&#xff1a; 安装qemu apt install qemu-system安装aarch64-none-elf-gcc 需要手动下载&#xff0c;下载地址&#xff1a;https://developer.arm.com/-/media/Files/downloads/gnu/13.2.rel1/binrel/arm-gnu-toolchain-13.2.rel1-x…...