pytorch使用技巧
pytorch使用技巧
1. 指定GPU编号
设置当前使用的GPU设备仅为0号设备,设备名称为 /gpu:0os.environ["CUDA_VISIBLE_DEVICES"] = "0"
设置当前使用的GPU设备为0, 1号两个设备,名称依次为 /gpu:0、/gpu:1: os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" ,根据顺序表示优先使用0号设备,然后使用1号设备。
2. 查看模型每层输出详情
Keras有一个简洁的API来查看模型的每一层输出尺寸,这在调试网络时非常有用。现在在PyTorch中也可以实现这个功能。
from torchsummary import summarysummary(your_model, input_size=(channels, H, W))
input_size 是根据你自己的网络模型的输入尺寸进行设置。
https://github.com/sksq96/pytorch-summary
3. 梯度裁剪(Gradient Clipping)
import torch.nn as nnoutputs = model(data)loss= loss_fn(outputs, target)optimizer.zero_grad()loss.backward()nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)optimizer.step()
nn.utils.clip_grad_norm_ 的参数:
-
parameters – 一个基于变量的迭代器,会进行梯度归一化
-
max_norm – 梯度的最大范数
-
norm_type – 规定范数的类型,默认为L2
4. 扩展单张图片维度
因为在训练时的数据维度一般都是 (batch_size, c, h, w),而在测试时只输入一张图片,所以需要扩展维度,扩展维度有多个方法:
import cv2import torchimage = cv2.imread(img_path)image = torch.tensor(image)print(image.size())img = image.view(1, *image.size())print(img.size())# output:# torch.Size([h, w, c])# torch.Size([1, h, w, c])
或
import cv2import numpy as npimage = cv2.imread(img_path)print(image.shape)img = image[np.newaxis, :, :, :]print(img.shape)# output:# (h, w, c)# (1, h, w, c)
import cv2import torchimage = cv2.imread(img_path)image = torch.tensor(image)print(image.size())img = image.unsqueeze(dim=0)print(img.size())img = img.squeeze(dim=0)print(img.size())# output:# torch.Size([(h, w, c)])# torch.Size([1, h, w, c])# torch.Size([h, w, c])
tensor.unsqueeze(dim):扩展维度,dim指定扩展哪个维度。
tensor.squeeze(dim):去除dim指定的且size为1的维度,维度大于1时,squeeze()不起作用,不指定dim时,去除所有size为1的维度。
5. 独热编码
在PyTorch中使用交叉熵损失函数的时候会自动把label转化成onehot,所以不用手动转化,而使用MSE需要手动转化成onehot编码。
import torchclass_num = 8batch_size = 4def one_hot(label):"""将一维列表转换为独热编码"""label = label.resize_(batch_size, 1)m_zeros = torch.zeros(batch_size, class_num)# 从 value 中取值,然后根据 dim 和 index 给相应位置赋值onehot = m_zeros.scatter_(1, label, 1) # (dim,index,value)return onehot.numpy() # Tensor -> Numpylabel = torch.LongTensor(batch_size).random_() % class_num # 对随机数取余print(one_hot(label))# output:[[0. 0. 0. 1. 0. 0. 0. 0.][0. 0. 0. 0. 1. 0. 0. 0.][0. 0. 1. 0. 0. 0. 0. 0.][0. 1. 0. 0. 0. 0. 0. 0.]]
6. 防止验证模型时爆显存
验证模型时不需要求导,即不需要梯度计算,关闭autograd,可以提高速度,节约内存。如果不关闭可能会爆显存。
with torch.no_grad():# 使用model进行预测的代码pass
意思就是PyTorch的缓存分配器会事先分配一些固定的显存,即使实际上tensors并没有使用完这些显存,这些显存也不能被其他应用使用。这个分配过程由第一次CUDA内存访问触发的。
而 torch.cuda.empty_cache() 的作用就是释放缓存分配器当前持有的且未占用的缓存显存,以便这些显存可以被其他GPU应用程序中使用,并且通过 nvidia-smi命令可见。注意使用此命令不会释放tensors占用的显存。
对于不用的数据变量,Pytorch 可以自动进行回收从而释放相应的显存。
7. 学习率衰减
import torch.optim as optimfrom torch.optim import lr_scheduler# 训练前的初始化optimizer = optim.Adam(net.parameters(), lr=0.001)scheduler = lr_scheduler.StepLR(optimizer, 10, 0.1) # # 每过10个epoch,学习率乘以0.1# 训练过程中for n in n_epoch:scheduler.step()...
8. 冻结某些层的参数
参考:Pytorch 冻结预训练模型的某一层
https://www.zhihu.com/question/311095447/answer/589307812
在加载预训练模型的时候,我们有时想冻结前面几层,使其参数在训练过程中不发生变化。
我们需要先知道每一层的名字,通过如下代码打印:
net = Network() # 获取自定义网络结构for name, value in net.named_parameters():print('name: {0},\t grad: {1}'.format(name, value.requires_grad))
name: cnn.VGG_16.convolution1_1.weight, grad: Truename: cnn.VGG_16.convolution1_1.bias, grad: Truename: cnn.VGG_16.convolution1_2.weight, grad: Truename: cnn.VGG_16.convolution1_2.bias, grad: Truename: cnn.VGG_16.convolution2_1.weight, grad: Truename: cnn.VGG_16.convolution2_1.bias, grad: Truename: cnn.VGG_16.convolution2_2.weight, grad: Truename: cnn.VGG_16.convolution2_2.bias, grad: True
后面的True表示该层的参数可训练,然后我们定义一个要冻结的层的列表:
no_grad = ['cnn.VGG_16.convolution1_1.weight','cnn.VGG_16.convolution1_1.bias','cnn.VGG_16.convolution1_2.weight','cnn.VGG_16.convolution1_2.bias']
net = Net.CTPN() # 获取网络结构for name, value in net.named_parameters():if name in no_grad:value.requires_grad = Falseelse:value.requires_grad = True
name: cnn.VGG_16.convolution1_1.weight, grad: Falsename: cnn.VGG_16.convolution1_1.bias, grad: Falsename: cnn.VGG_16.convolution1_2.weight, grad: Falsename: cnn.VGG_16.convolution1_2.bias, grad: Falsename: cnn.VGG_16.convolution2_1.weight, grad: Truename: cnn.VGG_16.convolution2_1.bias, grad: Truename: cnn.VGG_16.convolution2_2.weight, grad: Truename: cnn.VGG_16.convolution2_2.bias, grad: True
可以看到前两层的weight和bias的requires_grad都为False,表示它们不可训练。
最后在定义优化器时,只对requires_grad为True的层的参数进行更新。
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.01)
9. 对不同层使用不同学习率
net = Network() # 获取自定义网络结构for name, value in net.named_parameters():print('name: {}'.format(name))# 输出:# name: cnn.VGG_16.convolution1_1.weight# name: cnn.VGG_16.convolution1_1.bias# name: cnn.VGG_16.convolution1_2.weight# name: cnn.VGG_16.convolution1_2.bias# name: cnn.VGG_16.convolution2_1.weight# name: cnn.VGG_16.convolution2_1.bias# name: cnn.VGG_16.convolution2_2.weight# name: cnn.VGG_16.convolution2_2.bias
对 convolution1 和 convolution2 设置不同的学习率,首先将它们分开,即放到不同的列表里:
conv1_params = []conv2_params = []for name, parms in net.named_parameters():if "convolution1" in name:conv1_params += [parms]else:conv2_params += [parms]# 然后在优化器中进行如下操作:optimizer = optim.Adam([{"params": conv1_params, 'lr': 0.01},{"params": conv2_params, 'lr': 0.001},],weight_decay=1e-3,)
我们将模型划分为两部分,存放到一个列表里,每部分就对应上面的一个字典,在字典里设置不同的学习率。当这两部分有相同的其他参数时,就将该参数放到列表外面作为全局参数,如上面的`weight_decay`。
也可以在列表外设置一个全局学习率,当各部分字典里设置了局部学习率时,就使用该学习率,否则就使用列表外的全局学习率。
相关文章:
pytorch使用技巧
pytorch使用技巧 1. 指定GPU编号 设置当前使用的GPU设备仅为0号设备,设备名称为 /gpu:0os.environ["CUDA_VISIBLE_DEVICES"] "0" 设置当前使用的GPU设备为0, 1号两个设备,名称依次为 /gpu:0、/gpu:1: os.environ[&quo…...
从用户数据到区块链:Facebook如何利用去中心化技术
在数字化时代,用户数据的管理和保护已成为科技公司面临的重大挑战。作为全球最大的社交网络平台之一,Facebook不仅在用户数据的处理上积累了丰富的经验,也在探索如何利用去中心化技术,如区块链,来改进其数据管理和用户…...
Elasticsearch之bool查询
bool 查询是 Elasticsearch 中最常用的复合查询类型,允许将多个查询组合在一起。它通过逻辑操作符(如 must、should、must_not 和 filter)来构建复杂的查询条件,从而满足多条件匹配、逻辑与(AND)、或&#…...
IntelliJ IDEA 创建 Java 项目指南
IntelliJ IDEA 是一款功能强大的集成开发环境(IDE),广泛用于 Java 开发。本文将介绍如何在 IntelliJ IDEA 中创建一个新的 Java 项目,包括环境的设置和基本配置。更多问题,请查阅 一、安装 IntelliJ IDEA 1. 下载 In…...
一起学Java(13)-[日志篇]教你分析SLF4J和Log4j2源码,掌握SLF4J与Log4j2桥接集成原理
研究完SLF4J和Logback这种无缝集成的方式(一起学Java(12)-[日志篇]教你分析SLF4J源码,掌握SLF4J如何与Logback无缝集成的原理),继续研究Log4j2和SLF4J这种需要桥接集成的方式。 一、桥接包如何与SLF4J集成 我们已经知道SLF4J利用ServiceLoader机制&…...
深入Redis:核心的缓存
Redis最主要的用途,主要有三个方面:存储数据、缓存、消息队列。 其中,缓存是Redis最常用的场景。Redis使用内存作为硬盘的缓存。把用户集中访问的20%数据放到缓存中去,可以应对80%的请求。 数据库是非常重要的组件,但…...
集群聊天服务器项目【C++】项目介绍和环境搭建
前言:学习一个基于C集群聊天服务器的项目,记录学习的内容和学习的过程。 1.项目介绍 在 Linux 环境下基于 muduo 开发的集群聊天服务器。实现新用户注册、用户登录、添加好友、添加群组、好友通信、群组聊天、保持离线消息等功能。 2.技术栈 Json序列…...
c++ #include <memory> 智能指针介绍
#include <memory> 是 C 标准库中的头文件,用于支持智能指针的功能。智能指针是现代 C 的一种资源管理工具,用于自动管理动态分配的内存,从而减少内存泄漏和悬挂指针等问题的发生。它提供了多种类型的智能指针,包括 std::un…...
32.递归、搜索、回溯之floodfill算法
0.简介 1.图像渲染 . - 力扣(LeetCode) 题目解析 算法原理 代码 class Solution {int[] dx { 0, 0, 1, -1 };int[] dy { 1, -1, 0, 0 };int m, n;int prev;public int[][] floodFill(int[][] image, int sr, int sc, int color) {if (image[sr][sc]…...
Vue3.5+ 响应式 Props 解构
你好同学,我是沐爸,欢迎点赞、收藏、评论和关注。 在 Vue 3.5 中,响应式 Props 解构已经稳定并默认启用。这意味着在 <script setup> 中从 defineProps 调用解构的变量现在是响应式的。这一改进大大简化了声明带有默认值的 props 的方…...
k8s中的认证授权
目录 一、kubernetes API 访问控制 1.1 UserAccount与ServiceAccount 1.1.1 ServiceAccount 1.1.2 ServiceAccount示例 二、认证(在k8s中建立认证用户) 2.1 创建UserAccount 2.2 RBAC(Role Based Access Control) 2.2.1 基于角色访问控制授权&…...
Leetcode 3291. Minimum Number of Valid Strings to Form Target I
Leetcode 3291. Minimum Number of Valid Strings to Form Target I 1. 解题思路2. 代码实现 题目链接:3291. Minimum Number of Valid Strings to Form Target I 1. 解题思路 这一题第一反应就是用一个字典树动态规划的方式,倒是也搞定了,…...
PostgreSQL的查看主从同步状态
PostgreSQL的查看主从同步状态 PostgreSQL 提供了一些系统视图和函数,查看和监控主从同步的状态。 1 在主节点上查看同步状态 pg_stat_replication 视图 在主节点上,可以通过查询 pg_stat_replication 视图来查看复制的详细状态信息,包括…...
Java多态性的理解
方法的覆盖 子类的方法重写了父类的方法,相当于对原来的方法进行了增强,接口就是这样的思想。 属性的隔离(Java中什么情况下都不会属性覆盖,python可能会覆盖) public class Main {public static void main(String[…...
安全工具 | 使用Burp Suite的10个小tips
Burp Suite 应用程序中有用功能的集合 img Burp Suite 是一款出色的分析工具,用于测试 Web 应用程序和系统的安全漏洞。它有很多很棒的功能可以在渗透测试中使用。您使用它的次数越多,您就越发现它的便利功能。 本文内容是我在测试期间学到并经常的主要…...
企业项目中字符串工具类
此工具类暂时包含如下功能: isEmpty()判断字符串是否为空subSpecifiedString()判断字符串是否超出指定长度,超出则截取到指定长度yearMonthToDate()将年月的字符串转成年月日格式 yearMonthToDateTime()将年月的字符串转成年月日时分秒格式 package co…...
下载github patch到本地
以下是几种从 GitHub 上下载以.patch 结尾的补丁文件的方法: 通过浏览器直接下载 打开包含该.patch 文件的 GitHub 仓库。在仓库的文件列表中找到对应的.patch 文件。点击该文件,浏览器会显示文件的内容,在页面的右上角通常会有一个“Raw”…...
C++基础部分代码
C OOP面对对象 this指针 C:各种各样的函数定义 struct C:类》实体的抽象类型 实体(属性,行为)-》ADT(abstract data type) OOP语言的四大特征是什么? 抽象 封装/隐藏 继承 多态 访问限定符:public公有的 private私有的…...
neo4j(spring) 使用示例
文章目录 前言一、neo4j是什么二、开始编码1. yml 配置2. crud 测试3. node relation 与java中对象的关系4. 编码测试 总结 前言 图数据库先驱者 neo4j:neo4j官网地址 可以选择桌面版安装等多种方式,我这里采用的是docker安装 直接执行docker安装命令: docker run…...
链接升级:Element UI <el-link> 的应用
链接升级:Element UI 的应用 一 . 创建文字链接1.1 注册路由1.2 创建文字链接 二 . 文字链接的属性2.1 文字链接的颜色2.2 是否显示下划线2.3 是否禁用状态2.4 填写跳转地址2.5 加入图标 在本篇文章中,我们将深入探索Element UI中的<el-link>组件—…...
3个关键功能解析:USBToolBox如何简化macOS与Windows的USB端口映射难题
3个关键功能解析:USBToolBox如何简化macOS与Windows的USB端口映射难题 【免费下载链接】tool the USBToolBox tool 项目地址: https://gitcode.com/gh_mirrors/too/tool 在Hackintosh和跨平台开发领域,USB端口映射一直是个令人头疼的技术难题。US…...
DeepSeek系统设计辅助效能断崖式下降的3个信号,第2个90%工程师至今未察觉!
更多请点击: https://kaifayun.com 第一章:DeepSeek系统设计辅助效能断崖式下降的3个信号,第2个90%工程师至今未察觉! 当 DeepSeek 的系统设计辅助能力突然变“笨”——接口建议频繁失准、上下文感知错乱、生成代码无法通过基础编…...
环境光遮蔽(Ambient Occlusion):揭秘那个让虚拟世界“有重量感“的阴影魔法
一、一个让我"开窍"的老木匠故事 我有个朋友是传统家具的修复师,他给我讲过一个让我至今难忘的故事。他说他刚入行时跟着一位 70 多岁的老木匠师父学习——师父让他做的第一件事不是雕花、不是榫卯——而是"看阴影"——这个看似奇怪的训练改变了…...
飞书远程控机:OpenClaw配置全攻略
本文详细介绍如何通过 OpenClaw 工具对接飞书开放平台,配置智能机器人实现 Windows 电脑的远程控制。主要内容涵盖文件管理和程序启动等核心功能的实现方法,并提供完整的配置指南与常见问题解决方案。 一、使用前提说明 1. 系统要求 仅适用于 Windows…...
SAP-ABAP:变量、常量、结构与内表声明(10篇博客合集) 第五篇:声明时的键值设计技巧:结构与内表的主键、非主键配置指南
变量、常量、结构与内表声明(10篇博客合集) 第五篇:声明时的键值设计技巧:结构与内表的主键、非主键配置指南如果把内表比作一张内存中的“数据库表”,那么键就是这张表的索引甚至主键。键的设计直接决定了数据的唯一性…...
为什么92%的DeepSeek二次开发团队在6个月内遭遇交付延迟?——基于17个真实项目的技术债务归因分析
更多请点击: https://intelliparadigm.com 第一章:为什么92%的DeepSeek二次开发团队在6个月内遭遇交付延迟?——基于17个真实项目的技术债务归因分析 在对17个采用DeepSeek-R1/VL模型开展定制化开发的工业级项目进行回溯审计后,我…...
DeepSeek-R1补全能力封测倒计时(仅剩72小时开放API灰度权限):这份内部测试SOP已被3家头部科技公司紧急采购
更多请点击: https://intelliparadigm.com 第一章:DeepSeek-R1代码补全能力封测全景概览 DeepSeek-R1 是深度求索(DeepSeek)推出的高性能开源推理模型,在代码补全场景中展现出显著的上下文理解力与多语言泛化能力。本…...
中兴光猫终极管理指南:解锁工厂模式与Telnet权限的实战教程
中兴光猫终极管理指南:解锁工厂模式与Telnet权限的实战教程 【免费下载链接】zteOnu A tool that can open ZTE onu device factory mode 项目地址: https://gitcode.com/gh_mirrors/zt/zteOnu 掌握中兴光猫的设备管理和权限获取能力是网络管理员和技术爱好者…...
为什么鸿蒙 App 最终都会走向状态驱动?
子玥酱 (掘金 / 知乎 / CSDN / 简书 同名) 大家好,我是 子玥酱,一名长期深耕在一线的前端程序媛 👩💻。曾就职于多家知名互联网大厂,目前在某国企负责前端软件研发相关工作,主要聚…...
基于Arduino与nRF24L01+的无线传感器平台设计与部署指南
1. 项目概述与设计思路如果你和我一样,喜欢在阳台或者小院子里种点蔬菜瓜果,那你肯定也遇到过这样的烦恼:出门几天,心里总惦记着家里的番茄苗是不是缺水了,小温室里的温度会不会太高。传统的温湿度计只能让你在现场读数…...
