当前位置: 首页 > news >正文

手撕深度学习中的优化器

深度学习中的优化算法采用的原理是梯度下降法,选取适当的初值params,不断迭代,进行目标函数的极小化,直到收敛。由于负梯度方向时使函数值下降最快的方向,在迭代的每一步,以负梯度方向更新params的值,从而达到减少函数值的目的。

Gradient descent in deep learning

在这里插入图片描述

Optimizer

class Optimizer:"""优化器基类,默认是L2正则化"""def __init__(self, lr, weight_decay):self.lr = lrself.weight_decay = weight_decaydef step(self, grads, params):# 计算当前时刻下降的步长decrement = self.compute_step(grads)if self.weight_decay:decrement += self.weight_decay * params# 更新参数params -= decrementdef compute_step(self, grads):raise NotImplementedError

SGD

随机梯度下降
θt=θ−η⋅gt\theta_t = \theta-\eta \cdot g_t θt=θηgt

  • 每次随机抽取一个batch的样本进行梯度下降

  • 对学习率敏感,太小收敛速度很慢,太大会在极小值附近震荡

  • 对于非凸函数,容易陷入局部最小值或鞍点

class SGD(Optimizer):"""stochastic gradient descent"""def __init__(self, lr=0.1, weight_decay=0.0):super().__init__(lr, weight_decay)def compute_step(self, grads):return self.lr * grads

SGDm

SGD中加入动量(momentum)模拟是物体运动时的惯性,即更新的时候在一定程度上保留之前更新的方向,同时利用当前batch的梯度微调最终的更新方向。这样一来,可以在一定程度上增加稳定性,从而学习地更快,并且还有一定摆脱局部最优的能力。
υt=γυt−1+gtθt=θt−1−ηυt\upsilon_t = \gamma \upsilon_{t-1} + g_t \qquad \theta_t=\theta_{t-1} - \eta \upsilon_t υt=γυt1+gtθt=θt1ηυt

  • gt是当前时刻的梯度,vt是当前时刻参数的下降距离
  • 带动量的小球滚下山坡,可能会错过山谷
class SGDm(Optimizer):"""stochastic gradient descent with momentum"""def __init__(self, lr=0.1, momentum=0.9, weight_decay=0.0):super().__init__(lr, weight_decay)self.momentum = momentumself.beta = 0def compute_step(self, grads):self.beta = self.momentum * self.beta + (1 - self.momentum) * gradsreturn self.lr * self.beta

Adagrad

θt=θt−1−η∑i=0t−1(gi)2gt−1\theta_t=\theta_{t-1} - \frac{\eta}{\sqrt{\sum^{t-1}_{i=0}{(g_i)^2}}}g_{t-1} θt=θt1i=0t1(gi)2ηgt1

  • 自适应调节学习率
  • 对低频的参数做较大的更新,对高频的做较小的更新,也因此,对于稀疏的数据它的表现很好,很好地提高了 SGD 的鲁棒性
  • 缺点是分母梯度的累积,最后梯度消失
class Adagrad(Optimizer):"""Divide the learning rate of each parameter by theroot-mean-square of its previous derivatives"""def __init__(self, lr=0.1, eps=1e-8, weight_decay=0.0):super().__init__(lr, weight_decay)self.eps = epsself.state_sum = 0def compute_step(self, grads):self.state_sum += grads ** 2decrement = grads / (self.state_sum ** 0.5 + self.eps) * self.lrreturn decrement

RMSProp

指数滑动平均更新梯度的平方,为解决Adagrad 梯度急剧下降而提出
υ1=g02υt=αυt−1+(1−α)(gt−1)2\upsilon_1 = g_0^2 \qquad \upsilon_t = \alpha\upsilon_{t-1} + (1-\alpha)(g_{t-1})^2 υ1=g02υt=αυt1+(1α)(gt1)2

θt=θt−1−ηυtgt−1\theta_t=\theta_{t-1} - \frac{\eta}{\sqrt{\upsilon_t}} g_{t-1} θt=θt1υtηgt1

class RMSProp(Optimizer):"""Root Mean Square Prop optimizer"""def __init__(self, lr=0.1, alhpa=0.99, eps=1e-8, weight_decay=0.0):super().__init__(lr, weight_decay)self.eps = epsself.alpha = alhpaself.state_sum = 0def compute_step(self, grads):self.state_sum = self.alpha * self.state_sum + (1 - self.alpha) * grads ** 2decrement = grads / (self.state_sum ** 0.5 + self.eps) * self.lrreturn decrement

Adam

SGDmRMSProp的结合,Adam 算法通过计算梯度的一阶矩估计和二阶矩估计而为不同的参数设计独立的自适应性学习率。

  • SGDm

θt=θt−1−mtmt=β1mt−1+(1−β1)gt−1\theta_t=\theta_{t-1} - m_t \qquad m_t = \beta_1 m_{t-1} + (1-\beta_1)g_{t-1} θt=θt1mtmt=β1mt1+(1β1)gt1

  • RMSProp

θt=θt−1−ηυtgt−1\theta_t=\theta_{t-1} - \frac{\eta}{\sqrt{\upsilon_t}} g_{t-1} θt=θt1υtηgt1

υ1=g02υt=β2υt−1+(1−β2)(gt−1)2\upsilon_1 = g_0^2 \qquad \upsilon_t = \beta_2\upsilon_{t-1} + (1-\beta_2)(g_{t-1})^2 υ1=g02υt=β2υt1+(1β2)(gt1)2

  • Adam

θt=θt−1−ηυt′+εmt′\theta_t = \theta_{t-1} - \frac{\eta}{\sqrt{\upsilon_t'+\varepsilon}} m_t' θt=θt1υt+εηmt

mt′=mt1−β1tvt′=vt1−β2tβ1=0.9β2=0.999m_t' = \frac{m_t}{1-\beta_1^t} \qquad v_t' = \frac{v_t}{1-\beta_2^t} \qquad \beta_1=0.9 \quad \beta_2=0.999 mt=1β1tmtvt=1β2tvtβ1=0.9β2=0.999

class Adam(Optimizer):"""combination of SGDm and RMSProp"""def __init__(self, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):super().__init__(lr, weight_decay)self.eps = epsself.beta1, self.beta2 = betasself.mt = self.vt = 0self._t = 0def compute_step(self, grads):self._t += 1self.mt = self.beta1 * self.mt + (1 - self.beta1) * gradsself.vt = self.beta2 * self.vt + (1 - self.beta2) * (grads ** 2)mt = self.mt / (1 - self.beta1 ** self._t)vt = self.vt / (1 - self.beta2 ** self._t)decrement = mt / (vt ** 0.5 + self.eps) * self.lrreturn decrement

我平时做视觉任务主要用SGDm和Adam两个优化器,感觉带正则化的SGDm的效果非常好,然后调一下学习率和衰减策略


参考资料:

torch.optim — PyTorch documentation
tinynn: A lightweight deep learning library

相关文章:

手撕深度学习中的优化器

深度学习中的优化算法采用的原理是梯度下降法,选取适当的初值params,不断迭代,进行目标函数的极小化,直到收敛。由于负梯度方向时使函数值下降最快的方向,在迭代的每一步,以负梯度方向更新params的值&#…...

英文打字小游戏

目录 1 实验目的 2 实验报告内容 3 实验题目 4 实验环境 5 实验分析和设计思路 6 流程分析和类图结构 ​编辑 7. 实验结果与测试分析 8. 总结 这周没有更新任何的文章,感到十分的抱歉。因为我们老师让我们做一个英文打字的小游戏,并要求撰写实验…...

PCB生产工艺流程三:生产PCB的内层线路有哪7步

PCB生产工艺流程三:生产PCB的内层线路有哪7步 在我们的PCB生产工艺流程的第一步就是内层线路,那么它的流程又有哪些步骤呢?接下来我们就以内层线路的流程为主题,进行详细的分析。 由半固化片和铜箔压合而成,用于…...

算法竞赛进阶指南0x61 最短路

对于一张有向图,我们一般有邻接矩阵和邻接表两种存储方式。对于无向图,可以把无向边看作两条方向相反的有向边,从而采用与有向图一样的存储方式。 $$ 邻接矩阵的空间复杂度为 O(n^2),因此我们一般不采用这种方式 $$ 我们用数组模…...

[学习篇] Autoreleasepool

参考文章: https://www.jianshu.com/p/ec2c854b2efd https://suhou.github.io/2018/01/21/%E5%B8%A6%E7%9D%80%E9%97%AE%E9%A2%98%E7%9C%8B%E6%BA%90%E7%A0%81----%E5%AD%90%E7%BA%BF%E7%A8%8BAutoRelease%E5%AF%B9%E8%B1%A1%E4%BD%95%E6%97%B6%E9%87%8A%E6%94%BE/ …...

晶体基本知识

文章目录晶体基本知识基本概念晶胞<晶格<晶粒<晶体晶胞原子坐标(原子分数坐标)六方晶系与四轴定向七大晶系和十四种点阵结构学习资料吉林大学某实验室教程---知乎系列晶体与压敏器件晶体基本知识 基本概念 晶胞<晶格&#xff1c…...

免费CRM如何进行选择?

如今CRM领域成为炙手可热的赛道,很多CRM系统厂商甚至打出完全免费的口号,是否真的存在完全免费的crm系统?很多企业在免费使用过程中会出现被迫终止的问题,需要花费高价钱才能继续使用,那么,免费crm系统哪个…...

关于金融类iOS套壳上架,我帮你总结了这些经验

首先说明,本文中出现的案例的,没有特别的专门针对谁,只是用于分析,如有觉得不妥的,请及时联系我删除,鉴于本文发出之后,可能造成的一些影响,所以大家看看就好了,千万不要…...

4年功能测试月薪9.5K,3个月时间成功进阶自动化,跳槽涨薪6k后我的路还很长...

前言 其实最开始我并不是互联网从业者,是经历了一场六个月的培训才入的行,这个经历仿佛就是一个遮羞布,不能让任何人知道,就算有面试的时候被问到你是不是被培训的,我还是不能承认这段历史。我是为了生存,…...

python url解码详解

python url解码 url是数据的一个部分,一般会用来做什么呢?比如网站的 URL,比如搜索引擎中的 url,再比如网页中的图片等。 你也许不知道,在 Web页面中的图片、链接、超链接都是 URL,也就是 url。 而如果想要…...

leetcode102:二叉树的层序遍历

给你二叉树的根节点 root ,返回其节点值的 层序遍历 。 (即逐层地,从左到右访问所有节点)。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出:[[3],[9,20],[15,7]] 示例 2: 输入…...

深度学习openMMLab的介绍和使用

文章目录MMCV介绍MMCV的安装修改链接中的cu113修改链接中的torch1.10.0物体分类MMCLS源码下载配置参数解读配置文件的组成如何生成完整配置文件定义自己的数据集构建自己的数据集训练自己的任务物体检测MMDetection语义分割MMSegmentation姿态估计MMPose未完成,持续…...

【vue2】axios请求与axios拦截器的使用详解

🥳博 主:初映CY的前说(前端领域) 🌞个人信条:想要变成得到,中间还有做到! 🤘本文核心:当我们在路由跳转前与后我们可实现触发的操作 【前言】ajax是一种在javaScript代码中发请…...

文件上传都发生了啥

一直在用组件库做文件上传,那里面的原理到底是啥,自己写能不能写一个upload框出来呢? (一)基本原理 浏览器端提供了一个表单,在用户提交请求后,将文件数据和其他表单信息编码并上传至服务器端&#xff0…...

【vim进阶】vim编辑器的多文件操作(如何打开多个文件,如何进行文件间的切换,如何关闭其中的某一个文件)

一、如何打开多个文件? 方法一:启动打开 现在有多个文件 file1 ,file2 , … ,filen. 现在举例打开两个文件 file1,file2 vim file1 file2该方式打开文件,显示屏默认显示第一个文件也就是 file1。 方法二&#xff…...

ToBeWritten之车辆通信

也许每个人出生的时候都以为这世界都是为他一个人而存在的,当他发现自己错的时候,他便开始长大 少走了弯路,也就错过了风景,无论如何,感谢经历 转移发布平台通知:将不再在CSDN博客发布新文章,敬…...

自定义 Jackson 的 ObjectMapper, springboot多个模块共同引用,爽

springboot多个模块共同引用自定义ObjectMapper 🚃统一配置示例自定义 Jackson 的 ObjectMapper更改时区为东八区, 优点是在多个模块中都可以使用同一种方式来进行配置,方便维护和修改 统一配置 假设有一个 Spring Boot 项目,包含多个模块&…...

【面试】Redis面试题

文章目录概述什么是Redis?Redis有哪些优缺点?使用redis有哪些好处?为什么要用 Redis / 为什么要用缓存为什么要用 Redis 而不用 map/guava 做缓存?Redis为什么这么快Redis的应用场景持久化什么是Redis持久化?Redis 的持久化机制是…...

前端后端交互系列之原生Ajax的使用

目录前言一,Ajax概述二,基础知识之Http协议2.1 请求报文2.2 响应报文2.3 如何查看通信报文三,Ajax简单案例3.1 Express框架创建服务端3.2 Ajax案例后台准备3.3 Ajax案例前台准备3.4 发送get请求3.5 发送带有参数的Ajax请求3.6 发送post请求3.…...

openGauss 5.0企业版主从部署,实战狂飙

📢📢📢📣📣📣 哈喽!大家好,我是【IT邦德】,江湖人称jeames007,10余年DBA及大数据工作经验 一位上进心十足的【大数据领域博主】!😜&am…...

使用VSCode开发Django指南

使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架,专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用,其中包含三个使用通用基本模板的页面。在此…...

【位运算】消失的两个数字(hard)

消失的两个数字(hard) 题⽬描述:解法(位运算):Java 算法代码:更简便代码 题⽬链接:⾯试题 17.19. 消失的两个数字 题⽬描述: 给定⼀个数组,包含从 1 到 N 所有…...

SCAU期末笔记 - 数据分析与数据挖掘题库解析

这门怎么题库答案不全啊日 来简单学一下子来 一、选择题(可多选) 将原始数据进行集成、变换、维度规约、数值规约是在以下哪个步骤的任务?(C) A. 频繁模式挖掘 B.分类和预测 C.数据预处理 D.数据流挖掘 A. 频繁模式挖掘:专注于发现数据中…...

iPhone密码忘记了办?iPhoneUnlocker,iPhone解锁工具Aiseesoft iPhone Unlocker 高级注册版​分享

平时用 iPhone 的时候,难免会碰到解锁的麻烦事。比如密码忘了、人脸识别 / 指纹识别突然不灵,或者买了二手 iPhone 却被原来的 iCloud 账号锁住,这时候就需要靠谱的解锁工具来帮忙了。Aiseesoft iPhone Unlocker 就是专门解决这些问题的软件&…...

基于Uniapp开发HarmonyOS 5.0旅游应用技术实践

一、技术选型背景 1.跨平台优势 Uniapp采用Vue.js框架,支持"一次开发,多端部署",可同步生成HarmonyOS、iOS、Android等多平台应用。 2.鸿蒙特性融合 HarmonyOS 5.0的分布式能力与原子化服务,为旅游应用带来&#xf…...

2021-03-15 iview一些问题

1.iview 在使用tree组件时,发现没有set类的方法,只有get,那么要改变tree值,只能遍历treeData,递归修改treeData的checked,发现无法更改,原因在于check模式下,子元素的勾选状态跟父节…...

Ascend NPU上适配Step-Audio模型

1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤)&#x…...

C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)

名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...

Kubernetes 网络模型深度解析:Pod IP 与 Service 的负载均衡机制,Service到底是什么?

Pod IP 的本质与特性 Pod IP 的定位 纯端点地址:Pod IP 是分配给 Pod 网络命名空间的真实 IP 地址(如 10.244.1.2)无特殊名称:在 Kubernetes 中,它通常被称为 “Pod IP” 或 “容器 IP”生命周期:与 Pod …...

华为OD最新机试真题-数组组成的最小数字-OD统一考试(B卷)

题目描述 给定一个整型数组,请从该数组中选择3个元素 组成最小数字并输出 (如果数组长度小于3,则选择数组中所有元素来组成最小数字)。 输入描述 行用半角逗号分割的字符串记录的整型数组,0<数组长度<= 100,0<整数的取值范围<= 10000。 输出描述 由3个元素组成…...