pytorch使用技巧
pytorch使用技巧
1. 指定GPU编号
设置当前使用的GPU设备仅为0号设备,设备名称为 /gpu:0os.environ["CUDA_VISIBLE_DEVICES"] = "0"
设置当前使用的GPU设备为0, 1号两个设备,名称依次为 /gpu:0、/gpu:1: os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" ,根据顺序表示优先使用0号设备,然后使用1号设备。
2. 查看模型每层输出详情
Keras有一个简洁的API来查看模型的每一层输出尺寸,这在调试网络时非常有用。现在在PyTorch中也可以实现这个功能。
from torchsummary import summarysummary(your_model, input_size=(channels, H, W))
input_size 是根据你自己的网络模型的输入尺寸进行设置。
https://github.com/sksq96/pytorch-summary
3. 梯度裁剪(Gradient Clipping)
import torch.nn as nnoutputs = model(data)loss= loss_fn(outputs, target)optimizer.zero_grad()loss.backward()nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)optimizer.step()
nn.utils.clip_grad_norm_ 的参数:
-
parameters – 一个基于变量的迭代器,会进行梯度归一化
-
max_norm – 梯度的最大范数
-
norm_type – 规定范数的类型,默认为L2
4. 扩展单张图片维度
因为在训练时的数据维度一般都是 (batch_size, c, h, w),而在测试时只输入一张图片,所以需要扩展维度,扩展维度有多个方法:
import cv2import torchimage = cv2.imread(img_path)image = torch.tensor(image)print(image.size())img = image.view(1, *image.size())print(img.size())# output:# torch.Size([h, w, c])# torch.Size([1, h, w, c])
或
import cv2import numpy as npimage = cv2.imread(img_path)print(image.shape)img = image[np.newaxis, :, :, :]print(img.shape)# output:# (h, w, c)# (1, h, w, c)
import cv2import torchimage = cv2.imread(img_path)image = torch.tensor(image)print(image.size())img = image.unsqueeze(dim=0)print(img.size())img = img.squeeze(dim=0)print(img.size())# output:# torch.Size([(h, w, c)])# torch.Size([1, h, w, c])# torch.Size([h, w, c])
tensor.unsqueeze(dim):扩展维度,dim指定扩展哪个维度。
tensor.squeeze(dim):去除dim指定的且size为1的维度,维度大于1时,squeeze()不起作用,不指定dim时,去除所有size为1的维度。
5. 独热编码
在PyTorch中使用交叉熵损失函数的时候会自动把label转化成onehot,所以不用手动转化,而使用MSE需要手动转化成onehot编码。
import torchclass_num = 8batch_size = 4def one_hot(label):"""将一维列表转换为独热编码"""label = label.resize_(batch_size, 1)m_zeros = torch.zeros(batch_size, class_num)# 从 value 中取值,然后根据 dim 和 index 给相应位置赋值onehot = m_zeros.scatter_(1, label, 1) # (dim,index,value)return onehot.numpy() # Tensor -> Numpylabel = torch.LongTensor(batch_size).random_() % class_num # 对随机数取余print(one_hot(label))# output:[[0. 0. 0. 1. 0. 0. 0. 0.][0. 0. 0. 0. 1. 0. 0. 0.][0. 0. 1. 0. 0. 0. 0. 0.][0. 1. 0. 0. 0. 0. 0. 0.]]
6. 防止验证模型时爆显存
验证模型时不需要求导,即不需要梯度计算,关闭autograd,可以提高速度,节约内存。如果不关闭可能会爆显存。
with torch.no_grad():# 使用model进行预测的代码pass
意思就是PyTorch的缓存分配器会事先分配一些固定的显存,即使实际上tensors并没有使用完这些显存,这些显存也不能被其他应用使用。这个分配过程由第一次CUDA内存访问触发的。
而 torch.cuda.empty_cache() 的作用就是释放缓存分配器当前持有的且未占用的缓存显存,以便这些显存可以被其他GPU应用程序中使用,并且通过 nvidia-smi命令可见。注意使用此命令不会释放tensors占用的显存。
对于不用的数据变量,Pytorch 可以自动进行回收从而释放相应的显存。
7. 学习率衰减
import torch.optim as optimfrom torch.optim import lr_scheduler# 训练前的初始化optimizer = optim.Adam(net.parameters(), lr=0.001)scheduler = lr_scheduler.StepLR(optimizer, 10, 0.1) # # 每过10个epoch,学习率乘以0.1# 训练过程中for n in n_epoch:scheduler.step()...
8. 冻结某些层的参数
参考:Pytorch 冻结预训练模型的某一层
https://www.zhihu.com/question/311095447/answer/589307812
在加载预训练模型的时候,我们有时想冻结前面几层,使其参数在训练过程中不发生变化。
我们需要先知道每一层的名字,通过如下代码打印:
net = Network() # 获取自定义网络结构for name, value in net.named_parameters():print('name: {0},\t grad: {1}'.format(name, value.requires_grad))
name: cnn.VGG_16.convolution1_1.weight, grad: Truename: cnn.VGG_16.convolution1_1.bias, grad: Truename: cnn.VGG_16.convolution1_2.weight, grad: Truename: cnn.VGG_16.convolution1_2.bias, grad: Truename: cnn.VGG_16.convolution2_1.weight, grad: Truename: cnn.VGG_16.convolution2_1.bias, grad: Truename: cnn.VGG_16.convolution2_2.weight, grad: Truename: cnn.VGG_16.convolution2_2.bias, grad: True
后面的True表示该层的参数可训练,然后我们定义一个要冻结的层的列表:
no_grad = ['cnn.VGG_16.convolution1_1.weight','cnn.VGG_16.convolution1_1.bias','cnn.VGG_16.convolution1_2.weight','cnn.VGG_16.convolution1_2.bias']
net = Net.CTPN() # 获取网络结构for name, value in net.named_parameters():if name in no_grad:value.requires_grad = Falseelse:value.requires_grad = True
name: cnn.VGG_16.convolution1_1.weight, grad: Falsename: cnn.VGG_16.convolution1_1.bias, grad: Falsename: cnn.VGG_16.convolution1_2.weight, grad: Falsename: cnn.VGG_16.convolution1_2.bias, grad: Falsename: cnn.VGG_16.convolution2_1.weight, grad: Truename: cnn.VGG_16.convolution2_1.bias, grad: Truename: cnn.VGG_16.convolution2_2.weight, grad: Truename: cnn.VGG_16.convolution2_2.bias, grad: True
可以看到前两层的weight和bias的requires_grad都为False,表示它们不可训练。
最后在定义优化器时,只对requires_grad为True的层的参数进行更新。
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.01)
9. 对不同层使用不同学习率
net = Network() # 获取自定义网络结构for name, value in net.named_parameters():print('name: {}'.format(name))# 输出:# name: cnn.VGG_16.convolution1_1.weight# name: cnn.VGG_16.convolution1_1.bias# name: cnn.VGG_16.convolution1_2.weight# name: cnn.VGG_16.convolution1_2.bias# name: cnn.VGG_16.convolution2_1.weight# name: cnn.VGG_16.convolution2_1.bias# name: cnn.VGG_16.convolution2_2.weight# name: cnn.VGG_16.convolution2_2.bias
对 convolution1 和 convolution2 设置不同的学习率,首先将它们分开,即放到不同的列表里:
conv1_params = []conv2_params = []for name, parms in net.named_parameters():if "convolution1" in name:conv1_params += [parms]else:conv2_params += [parms]# 然后在优化器中进行如下操作:optimizer = optim.Adam([{"params": conv1_params, 'lr': 0.01},{"params": conv2_params, 'lr': 0.001},],weight_decay=1e-3,)
我们将模型划分为两部分,存放到一个列表里,每部分就对应上面的一个字典,在字典里设置不同的学习率。当这两部分有相同的其他参数时,就将该参数放到列表外面作为全局参数,如上面的`weight_decay`。
也可以在列表外设置一个全局学习率,当各部分字典里设置了局部学习率时,就使用该学习率,否则就使用列表外的全局学习率。
相关文章:
pytorch使用技巧
pytorch使用技巧 1. 指定GPU编号 设置当前使用的GPU设备仅为0号设备,设备名称为 /gpu:0os.environ["CUDA_VISIBLE_DEVICES"] "0" 设置当前使用的GPU设备为0, 1号两个设备,名称依次为 /gpu:0、/gpu:1: os.environ[&quo…...
从用户数据到区块链:Facebook如何利用去中心化技术
在数字化时代,用户数据的管理和保护已成为科技公司面临的重大挑战。作为全球最大的社交网络平台之一,Facebook不仅在用户数据的处理上积累了丰富的经验,也在探索如何利用去中心化技术,如区块链,来改进其数据管理和用户…...
Elasticsearch之bool查询
bool 查询是 Elasticsearch 中最常用的复合查询类型,允许将多个查询组合在一起。它通过逻辑操作符(如 must、should、must_not 和 filter)来构建复杂的查询条件,从而满足多条件匹配、逻辑与(AND)、或&#…...
IntelliJ IDEA 创建 Java 项目指南
IntelliJ IDEA 是一款功能强大的集成开发环境(IDE),广泛用于 Java 开发。本文将介绍如何在 IntelliJ IDEA 中创建一个新的 Java 项目,包括环境的设置和基本配置。更多问题,请查阅 一、安装 IntelliJ IDEA 1. 下载 In…...
一起学Java(13)-[日志篇]教你分析SLF4J和Log4j2源码,掌握SLF4J与Log4j2桥接集成原理
研究完SLF4J和Logback这种无缝集成的方式(一起学Java(12)-[日志篇]教你分析SLF4J源码,掌握SLF4J如何与Logback无缝集成的原理),继续研究Log4j2和SLF4J这种需要桥接集成的方式。 一、桥接包如何与SLF4J集成 我们已经知道SLF4J利用ServiceLoader机制&…...
深入Redis:核心的缓存
Redis最主要的用途,主要有三个方面:存储数据、缓存、消息队列。 其中,缓存是Redis最常用的场景。Redis使用内存作为硬盘的缓存。把用户集中访问的20%数据放到缓存中去,可以应对80%的请求。 数据库是非常重要的组件,但…...
集群聊天服务器项目【C++】项目介绍和环境搭建
前言:学习一个基于C集群聊天服务器的项目,记录学习的内容和学习的过程。 1.项目介绍 在 Linux 环境下基于 muduo 开发的集群聊天服务器。实现新用户注册、用户登录、添加好友、添加群组、好友通信、群组聊天、保持离线消息等功能。 2.技术栈 Json序列…...
c++ #include <memory> 智能指针介绍
#include <memory> 是 C 标准库中的头文件,用于支持智能指针的功能。智能指针是现代 C 的一种资源管理工具,用于自动管理动态分配的内存,从而减少内存泄漏和悬挂指针等问题的发生。它提供了多种类型的智能指针,包括 std::un…...
32.递归、搜索、回溯之floodfill算法
0.简介 1.图像渲染 . - 力扣(LeetCode) 题目解析 算法原理 代码 class Solution {int[] dx { 0, 0, 1, -1 };int[] dy { 1, -1, 0, 0 };int m, n;int prev;public int[][] floodFill(int[][] image, int sr, int sc, int color) {if (image[sr][sc]…...
Vue3.5+ 响应式 Props 解构
你好同学,我是沐爸,欢迎点赞、收藏、评论和关注。 在 Vue 3.5 中,响应式 Props 解构已经稳定并默认启用。这意味着在 <script setup> 中从 defineProps 调用解构的变量现在是响应式的。这一改进大大简化了声明带有默认值的 props 的方…...
k8s中的认证授权
目录 一、kubernetes API 访问控制 1.1 UserAccount与ServiceAccount 1.1.1 ServiceAccount 1.1.2 ServiceAccount示例 二、认证(在k8s中建立认证用户) 2.1 创建UserAccount 2.2 RBAC(Role Based Access Control) 2.2.1 基于角色访问控制授权&…...
Leetcode 3291. Minimum Number of Valid Strings to Form Target I
Leetcode 3291. Minimum Number of Valid Strings to Form Target I 1. 解题思路2. 代码实现 题目链接:3291. Minimum Number of Valid Strings to Form Target I 1. 解题思路 这一题第一反应就是用一个字典树动态规划的方式,倒是也搞定了,…...
PostgreSQL的查看主从同步状态
PostgreSQL的查看主从同步状态 PostgreSQL 提供了一些系统视图和函数,查看和监控主从同步的状态。 1 在主节点上查看同步状态 pg_stat_replication 视图 在主节点上,可以通过查询 pg_stat_replication 视图来查看复制的详细状态信息,包括…...
Java多态性的理解
方法的覆盖 子类的方法重写了父类的方法,相当于对原来的方法进行了增强,接口就是这样的思想。 属性的隔离(Java中什么情况下都不会属性覆盖,python可能会覆盖) public class Main {public static void main(String[…...
安全工具 | 使用Burp Suite的10个小tips
Burp Suite 应用程序中有用功能的集合 img Burp Suite 是一款出色的分析工具,用于测试 Web 应用程序和系统的安全漏洞。它有很多很棒的功能可以在渗透测试中使用。您使用它的次数越多,您就越发现它的便利功能。 本文内容是我在测试期间学到并经常的主要…...
企业项目中字符串工具类
此工具类暂时包含如下功能: isEmpty()判断字符串是否为空subSpecifiedString()判断字符串是否超出指定长度,超出则截取到指定长度yearMonthToDate()将年月的字符串转成年月日格式 yearMonthToDateTime()将年月的字符串转成年月日时分秒格式 package co…...
下载github patch到本地
以下是几种从 GitHub 上下载以.patch 结尾的补丁文件的方法: 通过浏览器直接下载 打开包含该.patch 文件的 GitHub 仓库。在仓库的文件列表中找到对应的.patch 文件。点击该文件,浏览器会显示文件的内容,在页面的右上角通常会有一个“Raw”…...
C++基础部分代码
C OOP面对对象 this指针 C:各种各样的函数定义 struct C:类》实体的抽象类型 实体(属性,行为)-》ADT(abstract data type) OOP语言的四大特征是什么? 抽象 封装/隐藏 继承 多态 访问限定符:public公有的 private私有的…...
neo4j(spring) 使用示例
文章目录 前言一、neo4j是什么二、开始编码1. yml 配置2. crud 测试3. node relation 与java中对象的关系4. 编码测试 总结 前言 图数据库先驱者 neo4j:neo4j官网地址 可以选择桌面版安装等多种方式,我这里采用的是docker安装 直接执行docker安装命令: docker run…...
链接升级:Element UI <el-link> 的应用
链接升级:Element UI 的应用 一 . 创建文字链接1.1 注册路由1.2 创建文字链接 二 . 文字链接的属性2.1 文字链接的颜色2.2 是否显示下划线2.3 是否禁用状态2.4 填写跳转地址2.5 加入图标 在本篇文章中,我们将深入探索Element UI中的<el-link>组件—…...
使用docker在3台服务器上搭建基于redis 6.x的一主两从三台均是哨兵模式
一、环境及版本说明 如果服务器已经安装了docker,则忽略此步骤,如果没有安装,则可以按照一下方式安装: 1. 在线安装(有互联网环境): 请看我这篇文章 传送阵>> 点我查看 2. 离线安装(内网环境):请看我这篇文章 传送阵>> 点我查看 说明:假设每台服务器已…...
iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘
美国西海岸的夏天,再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至,这不仅是开发者的盛宴,更是全球数亿苹果用户翘首以盼的科技春晚。今年,苹果依旧为我们带来了全家桶式的系统更新,包括 iOS 26、iPadOS 26…...
【Go】3、Go语言进阶与依赖管理
前言 本系列文章参考自稀土掘金上的 【字节内部课】公开课,做自我学习总结整理。 Go语言并发编程 Go语言原生支持并发编程,它的核心机制是 Goroutine 协程、Channel 通道,并基于CSP(Communicating Sequential Processes࿰…...
SpringTask-03.入门案例
一.入门案例 启动类: package com.sky;import lombok.extern.slf4j.Slf4j; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.cache.annotation.EnableCach…...
【网络安全】开源系统getshell漏洞挖掘
审计过程: 在入口文件admin/index.php中: 用户可以通过m,c,a等参数控制加载的文件和方法,在app/system/entrance.php中存在重点代码: 当M_TYPE system并且M_MODULE include时,会设置常量PATH_OWN_FILE为PATH_APP.M_T…...
省略号和可变参数模板
本文主要介绍如何展开可变参数的参数包 1.C语言的va_list展开可变参数 #include <iostream> #include <cstdarg>void printNumbers(int count, ...) {// 声明va_list类型的变量va_list args;// 使用va_start将可变参数写入变量argsva_start(args, count);for (in…...
(一)单例模式
一、前言 单例模式属于六大创建型模式,即在软件设计过程中,主要关注创建对象的结果,并不关心创建对象的过程及细节。创建型设计模式将类对象的实例化过程进行抽象化接口设计,从而隐藏了类对象的实例是如何被创建的,封装了软件系统使用的具体对象类型。 六大创建型模式包括…...
Rust 开发环境搭建
环境搭建 1、开发工具RustRover 或者vs code 2、Cygwin64 安装 https://cygwin.com/install.html 在工具终端执行: rustup toolchain install stable-x86_64-pc-windows-gnu rustup default stable-x86_64-pc-windows-gnu 2、Hello World fn main() { println…...
苹果AI眼镜:从“工具”到“社交姿态”的范式革命——重新定义AI交互入口的未来机会
在2025年的AI硬件浪潮中,苹果AI眼镜(Apple Glasses)正在引发一场关于“人机交互形态”的深度思考。它并非简单地替代AirPods或Apple Watch,而是开辟了一个全新的、日常可接受的AI入口。其核心价值不在于功能的堆叠,而在于如何通过形态设计打破社交壁垒,成为用户“全天佩戴…...
python爬虫——气象数据爬取
一、导入库与全局配置 python 运行 import json import datetime import time import requests from sqlalchemy import create_engine import csv import pandas as pd作用: 引入数据解析、网络请求、时间处理、数据库操作等所需库。requests:发送 …...
