当前位置: 首页 > news >正文

深度学习笔记:不同的反向传播迭代方法

1 随机梯度下降法SGD
在这里插入图片描述

随机梯度下降法每次迭代取梯度下降最大的方向更新。这一方法实现简单,但是在很多函数中,梯度下降的方向不一定指向函数最低点,这使得梯度下降呈现“之”字形,其效率较低
在这里插入图片描述

class SGD:"""随机梯度下降法(Stochastic Gradient Descent)"""def __init__(self, lr=0.01):self.lr = lrdef update(self, params, grads):for key in params.keys():params[key] -= self.lr * grads[key] 

2 Momentum

在这里插入图片描述
momentum即动量。该方法设置变量v代表梯度下降的速度,其中dL/dW(梯度值)代表改变速度的“受力”,而α则作为“阻力”,限制v变化。该方法进行梯度下降可以类比一个小球在三维平面上滚动。

在下面的示例中,可以看到虽然迭代方向还是呈“之”字形,但是在x方向,虽然梯度较小,但是由于受力始终在一个方向,速度逐渐加快。在y方向,虽然梯度大,但上下受力相反,使得y方向不会有很大偏移

在这里插入图片描述

class Momentum:"""Momentum SGD"""def __init__(self, lr=0.01, momentum=0.9):self.lr = lrself.momentum = momentumself.v = Nonedef update(self, params, grads):if self.v is None:self.v = {}for key, val in params.items():                                self.v[key] = np.zeros_like(val)for key in params.keys():self.v[key] = self.momentum*self.v[key] - self.lr*grads[key] params[key] += self.v[key]

在程序里一开始v设为None,在第一次调用update时会将v更新为和各权重形状一样的0矩阵

3 AdaGrad

在这里插入图片描述
AdaGrad的思路是根据上一轮迭代的变化量动态调整每一个权重的学习率。一个权重在迭代中变化量越大,其在下一轮中学习率就会减少更多。

在公式中,我们用h记录过去所有梯度的平方和(⊙代表矩阵元素相乘),在更新权重时之前变化较大的权重值变化量会变小。

由于h是不断累加的平方和,如果学习一直持续下去,W更新率会不断趋于0,要改善这一问题可以参考RMSProp,该方法会对较早更新的梯度逐渐“遗忘”,而更多反应新更新的状态

AdaGrad

class AdaGrad:"""AdaGrad"""def __init__(self, lr=0.01):self.lr = lrself.h = Nonedef update(self, params, grads):if self.h is None:self.h = {}for key, val in params.items():self.h[key] = np.zeros_like(val)for key in params.keys():self.h[key] += grads[key] * grads[key]params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)

在这里注意我们在h的每个元素中加上了微小的1e-7,这是为了防止h中有元素为0时,作为除数会报错。

RMSProp

class RMSprop:"""RMSprop"""def __init__(self, lr=0.01, decay_rate = 0.99):self.lr = lrself.decay_rate = decay_rateself.h = Nonedef update(self, params, grads):if self.h is None:self.h = {}for key, val in params.items():self.h[key] = np.zeros_like(val)for key in params.keys():self.h[key] *= self.decay_rateself.h[key] += (1 - self.decay_rate) * grads[key] * grads[key]params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)

RMSProp的方法和AdaGrad类似,除了每一轮迭代时会将h乘上一个decay_rate(大小在0-1)以减小之前梯度对h的影响

在这里插入图片描述
如图,一开始由于y方向梯度变化大,所以更新快,但因此y方向上学习率也减小较快,使得网络在后期逐渐沿x方向更新

Adam

Adam类似于momentum和AdaGrad两种方法的结合,其具体原理较为复杂,可以找原论文http://arxiv.org/abs/1412.6980v8

class Adam:"""Adam (http://arxiv.org/abs/1412.6980v8)"""def __init__(self, lr=0.001, beta1=0.9, beta2=0.999):self.lr = lrself.beta1 = beta1self.beta2 = beta2self.iter = 0self.m = Noneself.v = Nonedef update(self, params, grads):if self.m is None:self.m, self.v = {}, {}for key, val in params.items():self.m[key] = np.zeros_like(val)self.v[key] = np.zeros_like(val)self.iter += 1lr_t  = self.lr * np.sqrt(1.0 - self.beta2**self.iter) / (1.0 - self.beta1**self.iter)         for key in params.keys():#self.m[key] = self.beta1*self.m[key] + (1-self.beta1)*grads[key]#self.v[key] = self.beta2*self.v[key] + (1-self.beta2)*(grads[key]**2)self.m[key] += (1 - self.beta1) * (grads[key] - self.m[key])self.v[key] += (1 - self.beta2) * (grads[key]**2 - self.v[key])params[key] -= lr_t * self.m[key] / (np.sqrt(self.v[key]) + 1e-7)#unbias_m += (1 - self.beta1) * (grads[key] - self.m[key]) # correct bias#unbisa_b += (1 - self.beta2) * (grads[key]*grads[key] - self.v[key]) # correct bias#params[key] += self.lr * unbias_m / (np.sqrt(unbisa_b) + 1e-7)

利用mnist数据集对几种训练方式进行比较:
在该测试程序中,我们使用5层神经网络,每层神经元个数100。利用SGD, momentum, AdaGrad, Adam, RMSProp分别进行2000次迭代,并比较最终各网络的总损失

# coding: utf-8
import os
import sys
sys.path.append("D:\AI learning source code")  # 为了导入父目录的文件而进行的设定
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.util import smooth_curve
from common.multi_layer_net import MultiLayerNet
from common.optimizer import *# 0:读入MNIST数据==========
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)train_size = x_train.shape[0]
batch_size = 128
max_iterations = 2000# 1:进行实验的设置==========
optimizers = {}
optimizers['SGD'] = SGD()
optimizers['Momentum'] = Momentum()
optimizers['AdaGrad'] = AdaGrad()
optimizers['Adam'] = Adam()
optimizers['RMSprop'] = RMSprop()networks = {}
train_loss = {}
for key in optimizers.keys():networks[key] = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100],output_size=10)train_loss[key] = []    # 2:开始训练==========
for i in range(max_iterations):batch_mask = np.random.choice(train_size, batch_size)x_batch = x_train[batch_mask]t_batch = t_train[batch_mask]for key in optimizers.keys():grads = networks[key].gradient(x_batch, t_batch)optimizers[key].update(networks[key].params, grads)loss = networks[key].loss(x_batch, t_batch)train_loss[key].append(loss)if i % 100 == 0:print( "===========" + "iteration:" + str(i) + "===========")for key in optimizers.keys():loss = networks[key].loss(x_batch, t_batch)print(key + ":" + str(loss))# 3.绘制图形==========
markers = {"SGD": "o", "Momentum": "x", "AdaGrad": "s", "Adam": "D", "RMSprop": "v"}
x = np.arange(max_iterations)
for key in optimizers.keys():plt.plot(x, smooth_curve(train_loss[key]), marker=markers[key], markevery=100, label=key)
plt.xlabel("iterations")
plt.ylabel("loss")
plt.ylim(0, 1)
plt.legend()
plt.show()

实验结果如下
在这里插入图片描述

相关文章:

深度学习笔记:不同的反向传播迭代方法

1 随机梯度下降法SGD 随机梯度下降法每次迭代取梯度下降最大的方向更新。这一方法实现简单,但是在很多函数中,梯度下降的方向不一定指向函数最低点,这使得梯度下降呈现“之”字形,其效率较低 class SGD:"""随机…...

ElasticSearch 学习笔记总结(三)

文章目录一、ES 相关名词 专业介绍二、ES 系统架构三、ES 创建分片副本 和 elasticsearch-head插件四、ES 故障转移五、ES 应对故障六、ES 路由计算 和 分片控制七、ES集群 数据写流程八、ES集群 数据读流程九、ES集群 更新流程 和 批量操作十、ES 相关重要 概念 和 名词十一、…...

深入理解border以及应用

深入border属性以及应用&#x1f44f;&#x1f44f; border这个属性在开发过程中很常用&#xff0c;常常用它来作为边界的。但是大家真的了解border吗&#xff1f;以及它的形状是什么样子的。 我们先来看这样一段代码&#xff1a;&#x1f44f; <!--* Author: syk 185901…...

如何复现论文?什么是论文复现?

参考资料&#xff1a; 学习篇—顶会Paper复现方法 - 知乎 如何读论文&#xff1f;复现代码&#xff1f;_复现代码是什么意思 - CSDN 我是如何复现我人生的第一篇论文的 - 知乎 在我看来&#xff0c;论文复现应该有一个大前提和分为两个层次。 大前提是你要清楚地懂得自己要…...

22.2.28打卡 Codeforces Round #851 (Div. 2) A~C

A题 One and Two 题面翻译 题目描述 给你一个数列 a1,a2,…,ana_1, a_2, \ldots, a_na1​,a2​,…,an​ . 数列中的每一个数的值要么是 111 要么是 222 . 找到一个最小的正整数 kkk&#xff0c;使之满足&#xff1a; 1≤k≤n−11 \leq k \leq n-11≤k≤n−1 , anda1⋅a2⋅……...

Learining C++ No.12【vector】

引言&#xff1a; 北京时间&#xff1a;2023/2/27/11:42&#xff0c;高数考试还在进行中&#xff0c;我充分意识到了学校的不高级&#xff0c;因为题目真的没什么意思&#xff0c;虽然挺平易近人&#xff0c;但是……&#xff0c;考试期间时间比较放松&#xff0c;所以不能耽误…...

【数电基础】——逻辑代数运算

目录 1.概念 1.基本逻辑概念 2.基本逻辑电路&#xff08;与或非&#xff09; 逻辑与运算 与门电路&#xff1a; 逻辑或运算 或门电路&#xff1a; ​逻辑非运算&#xff08;逻辑反&#xff09; 非门电路​编辑 3.复合逻辑电路&#xff08;运算&#xff09; 与非逻辑…...

【Redis】什么是缓存与数据库双写不一致?怎么解决?

1. 热点缓存重建 我们以热点缓存 key 重建来一步步引出什么是缓存与数据库双写不一致&#xff0c;及其解决办法。 1.1 什么是热点缓存重建 在实际开发中&#xff0c;开发人员使用 “缓存 过期时间” 的策略来实现加速数据读写和内存使用率&#xff0c;这种策略能满足大多数…...

互联网衰退期,测试工程师35岁之路怎么走...

国内的互联网行业发展较快&#xff0c;所以造成了技术研发类员工工作强度比较大&#xff0c;同时技术的快速更新又需要员工不断的学习新的技术。因此淘汰率也比较高&#xff0c;超过35岁的基层研发类员工&#xff0c;往往因为家庭原因、身体原因&#xff0c;比较难以跟得上工作…...

动态规划(以背包问题为例)

1) 要求达到的目标为装入的背包的总价值最大&#xff0c;并且重量不超出2) 要求装入的物品不能重复动态规划(Dynamic Programming)算法的核心思想是&#xff1a;将大问题划分为小问题进行解决&#xff0c;从而一步步获取最优解的处理算法。动态规划算法与分治算法类似&#xff…...

Java异常

异常的体系结构 在java的Throwable下有Error和Exception两个子类 Error(错误):程序运行中出现了严重的问题,非代码性错误,无法处理,常见的有虚拟机运行错误和内存溢出等Exception(异常):是由于代码本身造成的问题,可以进行处理,异常一个可以分为运行时异常和编译时异常 运行…...

别克GL8改装完工,一起来看看效果

①豪华商务头等舱 别克GL8作为商务车&#xff0c;不管是家用还是商务接待&#xff0c;原车内饰都太掉档次了&#xff0c;所以车主要求全部换掉。>>织布座椅换成航空座椅 主副驾&#xff1a;改装纳帕皮 中排&#xff1a;改装水晶宝座豪华版航空座椅&#xff0c;带通风、加…...

mac 中 shell 一些知识

mac 设置环境变量首先得看你所使用的 shell shell 是一个命令行解释器&#xff0c;顾名思义就是机器外面的一层壳&#xff0c;用于人机交互&#xff0c;只要是人与电脑之间交互的接口&#xff0c;就可以称为 shell。表现为其作用是用户输入一条命令&#xff0c;shell 就立即解…...

CentOS 配置FTP(开启VSFTPD服务)

网上已经有很多关于VSFTPD的配置&#xff0c;但有两个通病&#xff0c;要么就是原理介绍太多&#xff0c;要么就是不完整&#xff0c;操作下来又要查询多篇文章才能用。 我这里不讲原理&#xff0c;只记录操作&#xff0c;尽可能通过复制下面的操作可以实现FTP读写功能。方便自…...

Http的请求方法

Http的请求方法对应的数据传输能力把Http请求分为Url类请求和Body类请求 1.Url类请求包括但不限于GET、HEAD、OPTIONS、TRACE 等请求方法 2.Body类请求包括但不限于POST、PUSH、PATCH、DELETE 等请求方法。 3.原因&#xff1a;get请求没有请求体&#xff08;好像也可以…...

Python字典-- 内附蓝桥题:统计数字

字典 ~~不定时更新&#x1f383;&#xff0c;上次更新&#xff1a;2023/02/28 &#x1f5e1;常用函数&#xff08;方法&#xff09; 1. dic.get(key) --> 判断字典 dic 是否有 key&#xff0c;有返回其对应的值&#xff0c;没有返回 None 举个栗子&#x1f330; dic …...

文本处理工具

Grep工具的基本使用grep作用&#xff1a;grep是行过滤工具&#xff1b;用于根据关键字进行行过滤提示&#xff1a;通过alias命令设置grep别名&#xff0c;搜索参数时带颜色显示alias grepgrep colorauto 命令语法格式&#xff1a;grep [选项] 参数 文件名grep命令选项&#xff…...

C++STL详解(三)——vector的介绍和使用

文章目录vector的介绍vector的使用vector的定义方式vector的空间增长问题reserve和resizevector的迭代器使用begin 和endrbegin和rendinsert 和erasefind函数元素访问vector迭代器失效问题1&#xff1a;inserse插入扩容时空间销毁造成野指针问题2&#xff1a;erase删除或者inse…...

GEBCO海洋数据下载

一、数据集简介 GEBCO&#xff08;General Bathymetric chart of the Oceans&#xff09;旨在为世界海洋提供最权威的、可公开获取的测深数据集。 目前的网格化测深数据集&#xff0c;即GEBCO_2022网格&#xff0c;是一个全球海洋和陆地的地形模型&#xff0c;在15角秒间隔的…...

【C++容器】vector、map、hash_map、unordered_map四大容器的性能分析【2023.02.28】

摘要 vector是标准容器对数组的封装&#xff0c;是一段连续的线性的内存。map底层是二叉排序树。hash_map是C11之前的无序map&#xff0c;unordered_map底层是hash表&#xff0c;涉及桶算法。现对各个容器的查询与”插入“性能做对比分析&#xff0c;方便后期选择。 测试方案…...

[10-3]软件I2C读写MPU6050 江协科技学习笔记(16个知识点)

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16...

【HTML-16】深入理解HTML中的块元素与行内元素

HTML元素根据其显示特性可以分为两大类&#xff1a;块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

NLP学习路线图(二十三):长短期记忆网络(LSTM)

在自然语言处理(NLP)领域,我们时刻面临着处理序列数据的核心挑战。无论是理解句子的结构、分析文本的情感,还是实现语言的翻译,都需要模型能够捕捉词语之间依时序产生的复杂依赖关系。传统的神经网络结构在处理这种序列依赖时显得力不从心,而循环神经网络(RNN) 曾被视为…...

08. C#入门系列【类的基本概念】:开启编程世界的奇妙冒险

C#入门系列【类的基本概念】&#xff1a;开启编程世界的奇妙冒险 嘿&#xff0c;各位编程小白探险家&#xff01;欢迎来到 C# 的奇幻大陆&#xff01;今天咱们要深入探索这片大陆上至关重要的 “建筑”—— 类&#xff01;别害怕&#xff0c;跟着我&#xff0c;保准让你轻松搞…...

基于单片机的宠物屋智能系统设计与实现(论文+源码)

本设计基于单片机的宠物屋智能系统核心是实现对宠物生活环境及状态的智能管理。系统以单片机为中枢&#xff0c;连接红外测温传感器&#xff0c;可实时精准捕捉宠物体温变化&#xff0c;以便及时发现健康异常&#xff1b;水位检测传感器时刻监测饮用水余量&#xff0c;防止宠物…...

RushDB开源程序 是现代应用程序和 AI 的即时数据库。建立在 Neo4j 之上

一、软件介绍 文末提供程序和源码下载 RushDB 改变了您处理图形数据的方式 — 不需要 Schema&#xff0c;不需要复杂的查询&#xff0c;只需推送数据即可。 二、Key Features ✨ 主要特点 Instant Setup: Be productive in seconds, not days 即时设置 &#xff1a;在几秒钟…...

手动给中文分词和 直接用神经网络RNN做有什么区别

手动分词和基于神经网络&#xff08;如 RNN&#xff09;的自动分词在原理、实现方式和效果上有显著差异&#xff0c;以下是核心对比&#xff1a; 1. 实现原理对比 对比维度手动分词&#xff08;规则 / 词典驱动&#xff09;神经网络 RNN 分词&#xff08;数据驱动&#xff09…...

基于 HTTP 的单向流式通信协议SSE详解

SSE&#xff08;Server-Sent Events&#xff09;详解 &#x1f9e0; 什么是 SSE&#xff1f; SSE&#xff08;Server-Sent Events&#xff09; 是 HTML5 标准中定义的一种通信机制&#xff0c;它允许服务器主动将事件推送给客户端&#xff08;浏览器&#xff09;。与传统的 H…...

Qt 按钮类控件(Push Button 与 Radio Button)(1)

文章目录 Push Button前提概要API接口给按钮添加图标给按钮添加快捷键 Radio ButtonAPI接口性别选择 Push Button&#xff08;鼠标点击不放连续移动快捷键&#xff09; Radio Button Push Button 前提概要 1. 之前文章中所提到的各种跟QWidget有关的各种属性/函数/方法&#…...

codeforces C. Cool Partition

目录 题目简述&#xff1a; 思路&#xff1a; 总代码&#xff1a; https://codeforces.com/contest/2117/problem/C 题目简述&#xff1a; 给定一个整数数组&#xff0c;现要求你对数组进行分割&#xff0c;但需满足条件&#xff1a;前一个子数组中的值必须在后一个子数组中…...