小白学Pytorch系列--Torch.optim API Base class(1)
小白学Pytorch系列–Torch.optim API Base class(1)
torch.optim是一个实现各种优化算法的包。大多数常用的方法都已得到支持,而且接口足够通用,因此将来还可以轻松集成更复杂的方法。
如何使用优化器
使用手torch.optim您必须构造一个优化器对象,该对象将保存当前状态,并将根据计算出的梯度更新参数。
构造它
要构造一个优化器,你必须给它一个包含参数(所有应该是变量)的可迭代对象来优化。然后,您可以指定特定于优化器的选项,如学习率、权重衰减等。
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
每个参数选项
优化器还支持指定每个参数的选项。要做到这一点,不是传递一个变量的迭代对象,而是传递一个dict的迭代对象。它们每个都将定义一个单独的参数组,并且应该包含一个params键,包含属于它的参数列表。其他键应该与优化器接受的关键字参数匹配,并将用作该组的优化选项。
注意: 您仍然可以将选项作为关键字参数传递。在没有覆盖它们的组中,它们将作为默认值使用。当您只想改变一个选项,同时在参数组之间保持所有其他选项一致时,这很有用。
例如,当想要指定每层的学习率时,这是非常有用的
optim.SGD([{'params': model.base.parameters()},{'params': model.classifier.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)
这意味着这个model.base
参数将使用默认的学习率1e-2
,model.classifier’
参数将使用1e-3
的学习率,所有参数将使用0.9的动量。
进行优化步骤
所有优化器都实现了一个step()
方法,用于更新参数。它有两种用法
optimizer.step()
这是大多数优化器支持的简化版本。该函数可以在梯度计算完成后调用,例如backward()
。
例如:
for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()
optimizer.step(closure)
一些优化算法(如共轭梯度和LBFGS)需要多次重新计算函数,所以你必须传入一个闭包,允许它们重新计算你的模型。闭包应该清除梯度,计算损失并返回。
for input, target in dataset:def closure():optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()return lossoptimizer.step(closure)
Base class
这部分参考了:https://zhuanlan.zhihu.com/p/87209990
PyTorch 的优化器基本都继承于 “class Optimizer”,这是所有 optimizer 的 base class。
下面是Optimizer的结构
class Optimizer(object):def __init__(self, params, defaults):self.defaults = defaultsself._hook_for_profile()if isinstance(params, torch.Tensor):raise TypeError("params argument given to the optimizer should be ""an iterable of Tensors or dicts, but got " +torch.typename(params))self.state = defaultdict(dict)self.param_groups = []param_groups = list(params)if len(param_groups) == 0:raise ValueError("optimizer got an empty parameter list")if not isinstance(param_groups[0], dict):param_groups = [{'params': param_groups}]for param_group in param_groups:self.add_param_group(param_group)def state_dict(self):...def load_state_dict(self, state_dict):...def cast(param, value):...def zero_grad(self, set_to_none: bool = False):...def step(self, closure):...def add_param_group(self, param_group): ...
init 函数初始化
params
和defaults
是两个重要的参数,defaults定义了全局优化默认值,params定义了模型参数和局部优化默认值。
add_param_group
defaultdict
的作用在于当字典里的 key 被查找但不存在时,返回的不是keyError而是一个默认值,此处
defaultdict(dict)`返回的默认值会是个空字典。最后一行调用的self.add_param_group(param_group),其中param_group是个字典,Key 就是params,Value 就是param_groups = list(params)。
def add_param_group(self, param_group):params = param_group['params']if isinstance(params, torch.Tensor):param_group['params'] = [params]elif isinstance(params, set):raise TypeError('optimizer')else:param_group['params'] = list(params)for param in param_group['params']:if not isinstance(param, torch.Tensor):raise TypeError("optimizer " + torch.typename(param))if not param.is_leaf:raise ValueError("can't optimize a non-leaf Tensor")for name, default in self.defaults.items():if default is required and name not in param_group:raise ValueError("parameter group didn't specify a value of required optimization parameter " +name)else:param_group.setdefault(name, default) # 给参数设置默认参数params = param_group['params']if len(params) != len(set(params)):warnings.warn("optimizer contains ", stacklevel=3)param_set = set()for group in self.param_groups:param_set.update(set(group['params']))if not param_set.isdisjoint(set(param_group['params'])): # 判断两个集合是否包含相同的元素raise ValueError("some parameters appear in more than one parameter group")self.param_groups.append(param_group)
zero_grad
就是将所有参数的梯度置为零p.grad.zero_()。detach_()
的作用是Detaches the Tensor from the graph that created it, making it a leaf. self.param_groups是列表,其中的元素是字典。
def zero_grad(self):r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""for group in self.param_groups:for p in group['params']:if p.grad is not None:p.grad.detach_()p.grad.zero_()
step
更新参数作用, 在父类 Optimizer 的 step
函数中只有一行代码raise NotImplementedError
。网络模型参数和优化器的参数都保存在列表 self.param_groups
的元素中,该元素以字典形式存储和访问具体的网络模型参数和优化器的参数。所以,可以通过两层循环访问网络模型的每一个参数 p
。获取到梯度d_p = p.grad.data
之后,根据优化器参数设置是否使用 momentum或者nesterov再对参数进行调整。最后一行 p.data.add_(-group['lr'], d_p)
的作用是对参数进行更新。state用于保存本次更新是优化器第几轮迭代更新参数。
下面以SGD优化器为例
def step(self, closure=None):loss = Noneif closure is not None:with torch.enable_grad():loss = closure()for group in self.param_groups:params_with_grad = []d_p_list = []momentum_buffer_list = []weight_decay = group['weight_decay']momentum = group['momentum']dampening = group['dampening']nesterov = group['nesterov']maximize = group['maximize']lr = group['lr']for p in group['params']:if p.grad is not None:params_with_grad.append(p)d_p_list.append(p.grad)state = self.state[p]if 'momentum_buffer' not in state:momentum_buffer_list.append(None)else:momentum_buffer_list.append(state['momentum_buffer'])F.sgd(params_with_grad,d_p_list,momentum_buffer_list,weight_decay=weight_decay,momentum=momentum,lr=lr,dampening=dampening,nesterov=nesterov,maximize=maximize,)# update momentum_buffers in statefor p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):state = self.state[p] ## 保存state['momentum_buffer'] = momentum_bufferreturn loss
F.sgd
def sgd(params: List[Tensor],d_p_list: List[Tensor],momentum_buffer_list: List[Optional[Tensor]],*,weight_decay: float,momentum: float,lr: float,dampening: float,nesterov: bool,maximize: bool):for i, param in enumerate(params):d_p = d_p_list[i]if weight_decay != 0:d_p = d_p.add(param, alpha=weight_decay)if momentum != 0:buf = momentum_buffer_list[i]if buf is None:buf = torch.clone(d_p).detach()momentum_buffer_list[i] = bufelse:buf.mul_(momentum).add_(d_p, alpha=1 - dampening)if nesterov:d_p = d_p.add(buf, alpha=momentum)else:d_p = bufalpha = lr if maximize else -lrparam.add_(d_p, alpha=alpha)
SGD上引入了一个Momentum(又叫Heavy Ball)的改进。
load_state_dict
加载优化器状态。
def load_state_dict(self, state_dict):# deepcopy, to be consistent with module APIstate_dict = deepcopy(state_dict)# Validate the state_dictgroups = self.param_groupssaved_groups = state_dict['param_groups']if len(groups) != len(saved_groups):raise ValueError("loaded state dict has a different number of ""parameter groups")param_lens = (len(g['params']) for g in groups)saved_lens = (len(g['params']) for g in saved_groups)if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):raise ValueError("loaded state dict contains a parameter group ""that doesn't match the size of optimizer's group")# Update the stateid_map = {old_id: p for old_id, p inzip(chain.from_iterable((g['params'] for g in saved_groups)),chain.from_iterable((g['params'] for g in groups)))}def cast(param, value):r"""Make a deep copy of value, casting all tensors to device of param."""if isinstance(value, torch.Tensor):# Floating-point types are a bit special here. They are the only ones# that are assumed to always match the type of params.if param.is_floating_point():value = value.to(param.dtype)value = value.to(param.device)return valueelif isinstance(value, dict):return {k: cast(param, v) for k, v in value.items()}elif isinstance(value, container_abcs.Iterable):return type(value)(cast(param, v) for v in value)else:return value# Copy state assigned to params (and cast tensors to appropriate types).# State that is not assigned to params is copied as is (needed for# backward compatibility).state = defaultdict(dict)for k, v in state_dict['state'].items():if k in id_map:param = id_map[k]state[param] = cast(param, v)else:state[k] = v
state_dict
以字典的形式返回优化器的状态。
def state_dict(self): # Save order indices instead of Tensorsparam_mappings = {}start_index = 0def pack_group(group):nonlocal start_indexpacked = {k: v for k, v in group.items() if k != 'params'}param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)if id(p) not in param_mappings})packed['params'] = [param_mappings[id(p)] for p in group['params']]start_index += len(packed['params'])return packedparam_groups = [pack_group(g) for g in self.param_groups]# Remap state to use order indices as keyspacked_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): vfor k, v in self.state.items()}return {'state': packed_state,'param_groups': param_groups,}
相关文章:

小白学Pytorch系列--Torch.optim API Base class(1)
小白学Pytorch系列–Torch.optim API Base class(1) torch.optim是一个实现各种优化算法的包。大多数常用的方法都已得到支持,而且接口足够通用,因此将来还可以轻松集成更复杂的方法。 如何使用优化器 使用手torch.optim您必须构造一个优化器对象&…...

flac格式如何转mp3,3招帮你搞定
flac格式如何转mp3,3招帮你搞定的方法来啦。当你的音频是flac格式是不是很头疼,又不知道怎么转mp3 。然后网上搜索出很多方法又不知道从哪个下手,是不是很疑惑?那今天就来看看小编推荐的方法吧,一定让你眼前一亮&#…...

Redis入门到入土(day01)
NoSQL概述 为什么用NoSQL 1、单机MySQL的美好年代 在90年代,一个网站的访问量一般不大,用单个数据库完全可以轻松应付! 在那个时候,更多的都是静态网页,动态交互类型的网站不多。 上述架构下,我们来看看…...

JVM垃圾回收GC 详解(java1.8)
目录 垃圾判断算法(你是不是垃圾?) 引用计数法 可达性算法 对象的引用 强引用 软引用 弱引用 虚引用 对象的自我救赎 垃圾回收算法--分代 标记清除算法 复制算法 标记整理法 垃圾处理器 垃圾判断算法(你是不是垃圾&…...
Mybatis-Plus -03 Mybatis-Plus实现CRUD
Mybatis-Plus实现CRUD 1 Insert增加2 ID生成策略3 Delete删除4 逻辑删除5 Update修改6 Select查询 Mybatis-Plus实现CRUD 通用 CRUD 封装**BaseMapper (opens new window)**接口,为 Mybatis-Plus 启动时自动解析实体表关系映射转换为 Mybatis 内部对象注入容器参数 …...

综合能源系统中基于电转气和碳捕集系统的热电联产建模与优化研究(Matlab代码实现)
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...

“智慧赋能 强链塑链”|工程物资供应链管理中的数字化应用
工程项目中的供应链管理至关重要 工程建设行业是国民经济的重要支柱之一,虽然在总产值上持续保持增长态势,但近年来行业的利润总额增速已连续多年呈现下降趋势。究其原因,可以大体从两个方面来看:一是行业盈利能力出现下降&#x…...

通过docker发布项目
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言例如:docker项目的发布方式 [docker发布的参考链接](https://www.cnblogs.com/emperorking/articles/11244253.html) 一、docker是什么?…...

为什么Spring和IDEA不推荐使用@Autowired注解?
在Spring开发中,Autowired注解是一个常用的依赖注入方式。但是,你可能会惊奇地发现,Spring和IDEA都不推荐使用Autowired注解。关于这个问题,其实答案相对统一,实际上用大白话说起来也容易理解。 官方答案 首先&#…...

windows下运行dpdk下的helloworld
打开“本地安全策略”管理单元,在搜索框输入secpol。 打开本地策略->用户权限分配->锁定内存页->添加用户或组->高级->立即查找 输入电脑用户名,选择并添加。点击确定后,重启电脑。 安装内核驱动,下载地址https://download.csdn.net/download/qq_36314864…...

【AI理论学习】深入理解Prompt Learning和Prompt Tuning
深入理解Prompt Learning和Prompt Tuning 背景Prompt Learning简介1. Prompt是什么?2. 为什么要使用Prompt?3. Prompt Learning的形式(举例)4. 有哪些Pre-training language model?5. 常见的Prompt Learning的方法 Pro…...
从Authy中导出账户和secret
本文转载于我的博客从Authy中导出账户和secret 前言 因为最近买了CanoKey,所以多算试一下CanoKey的TOTP功能,但是之前一直用的Authy并且它默认不支持导出功能 在网上找了一些文档,终于在github上找到了一个有效且简单的方法 目前网上大部分…...
图像锐度评分算法,方差,点锐度法,差分法,梯度法
图像锐度评分算法,方差,点锐度法,差分法,梯度法 图像锐度评分是用来描述图像清晰度的一个指标。常见的图像锐度评分算法包括方差法、点锐度法、差分法和梯度法等。 方差法:该方法是通过计算图像像素值的方差来评估图像…...
查询练习:连接查询
准备用于测试连接查询的数据: CREATE DATABASE testJoin;CREATE TABLE person (id INT,name VARCHAR(20),cardId INT );CREATE TABLE card (id INT,name VARCHAR(20) );INSERT INTO card VALUES (1, 饭卡), (2, 建行卡), (3, 农行卡), (4, 工商卡), (5, 邮政卡); S…...
【mmdeploy】【TODO】使用mmdeploy将mmdetection模型转tensorrt
mmdetection转换 文章目录 mmdetection转换mmdetection 自带转换ONNX——无法测试使用mmdeploy(0.6.0)使用mmdeploy转onnx使用mmdeploy直接转tensorRT调试记录 先上结论:作者最后是转tensorrt的小图才成功的,大图一直不行。文章仅作者自我记录使用&#…...

德赛西威上海车展重磅发布Smart Solution 2.0,有哪些革新点?
4月18日,全球瞩目的第二十届上海车展盛大启幕,作为国际领先的移动出行科技公司,德赛西威携智慧出行黑科技产品矩阵亮相,并以“智出行 共创享”为主题,重磅发布最新迭代的智慧出行解决方案——Smart Solution 2.0。 从…...
戴尔服务器是否需要开启cpupower.service
戴尔并不会默认开启cpupower.service,这取决于具体的操作系统和配置。cpupower.service是一个Linux系统服务,用于管理CPU的功耗和性能调节,可以通过调整CPU的频率和电源管理策略来降低能耗和温度。在某些情况下,开启cpupower.serv…...

day02_第一个Java程序
在开发第一个Java程序之前,我们必须对计算机的一些基础知识进行了解。 常用DOS命令 Java语言的初学者,学习一些DOS命令,会非常有帮助。DOS是一个早期的操作系统,现在已经被Windows系统取代,对于我们开发人员…...
【华为OD机试真题 】1011 - 第K个排列 (JAVA C++ Python JS) | 机试题+算法思路+考点+代码解析
文章目录 一、题目🔸题目描述🔸输入输出🔸样例1🔸样例2二、代码参考🔸C++代码🔸Java代码🔸Python代码🔸JS代码作者:KJ.JK🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🍂个人博客首页: KJ.JK 💖系列专栏:...

基于php的校园校园兼职网站的设计与实现
摘要 近年来,信息技术在大学校园中得到了广泛的应用,主要体现在两个方面:一是学校管理系统,包括教务管理、行政管理和分校管理,是我国大学管理和信息传递的主要渠道。二是学生生活服务平台。而随着大学生毕业人数的年…...

stm32G473的flash模式是单bank还是双bank?
今天突然有人stm32G473的flash模式是单bank还是双bank?由于时间太久,我真忘记了。搜搜发现,还真有人和我一样。见下面的链接:https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...

安宝特方案丨XRSOP人员作业标准化管理平台:AR智慧点检验收套件
在选煤厂、化工厂、钢铁厂等过程生产型企业,其生产设备的运行效率和非计划停机对工业制造效益有较大影响。 随着企业自动化和智能化建设的推进,需提前预防假检、错检、漏检,推动智慧生产运维系统数据的流动和现场赋能应用。同时,…...

Nuxt.js 中的路由配置详解
Nuxt.js 通过其内置的路由系统简化了应用的路由配置,使得开发者可以轻松地管理页面导航和 URL 结构。路由配置主要涉及页面组件的组织、动态路由的设置以及路由元信息的配置。 自动路由生成 Nuxt.js 会根据 pages 目录下的文件结构自动生成路由配置。每个文件都会对…...
Spring AI与Spring Modulith核心技术解析
Spring AI核心架构解析 Spring AI(https://spring.io/projects/spring-ai)作为Spring生态中的AI集成框架,其核心设计理念是通过模块化架构降低AI应用的开发复杂度。与Python生态中的LangChain/LlamaIndex等工具类似,但特别为多语…...
区块链技术概述
区块链技术是一种去中心化、分布式账本技术,通过密码学、共识机制和智能合约等核心组件,实现数据不可篡改、透明可追溯的系统。 一、核心技术 1. 去中心化 特点:数据存储在网络中的多个节点(计算机),而非…...
DAY 26 函数专题1
函数定义与参数知识点回顾:1. 函数的定义2. 变量作用域:局部变量和全局变量3. 函数的参数类型:位置参数、默认参数、不定参数4. 传递参数的手段:关键词参数5 题目1:计算圆的面积 任务: 编写一…...
书籍“之“字形打印矩阵(8)0609
题目 给定一个矩阵matrix,按照"之"字形的方式打印这个矩阵,例如: 1 2 3 4 5 6 7 8 9 10 11 12 ”之“字形打印的结果为:1,…...

针对药品仓库的效期管理问题,如何利用WMS系统“破局”
案例: 某医药分销企业,主要经营各类药品的批发与零售。由于药品的特殊性,效期管理至关重要,但该企业一直面临效期问题的困扰。在未使用WMS系统之前,其药品入库、存储、出库等环节的效期管理主要依赖人工记录与检查。库…...

使用VMware克隆功能快速搭建集群
自己搭建的虚拟机,后续不管是学习java还是大数据,都需要集群,java需要分布式的微服务,大数据Hadoop的计算集群,如果从头开始搭建虚拟机会比较费时费力,这里分享一下如何使用克隆功能快速搭建一个集群 先把…...

篇章一 论坛系统——前置知识
目录 1.软件开发 1.1 软件的生命周期 1.2 面向对象 1.3 CS、BS架构 1.CS架构编辑 2.BS架构 1.4 软件需求 1.需求分类 2.需求获取 1.5 需求分析 1. 工作内容 1.6 面向对象分析 1.OOA的任务 2.统一建模语言UML 3. 用例模型 3.1 用例图的元素 3.2 建立用例模型 …...