当前位置: 首页 > news >正文

pytorch使用技巧

pytorch使用技巧

1. 指定GPU编号

设置当前使用的GPU设备仅为0号设备,设备名称为 /gpu:0os.environ["CUDA_VISIBLE_DEVICES"] = "0"

设置当前使用的GPU设备为0, 1号两个设备,名称依次为 /gpu:0/gpu:1: 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" ,根据顺序表示优先使用0号设备,然后使用1号设备。

2. 查看模型每层输出详情

Keras有一个简洁的API来查看模型的每一层输出尺寸,这在调试网络时非常有用。现在在PyTorch中也可以实现这个功能。

from torchsummary import summarysummary(your_model, input_size=(channels, H, W))

input_size 是根据你自己的网络模型的输入尺寸进行设置。

https://github.com/sksq96/pytorch-summary

3. 梯度裁剪(Gradient Clipping)

import torch.nn as nn
outputs = model(data)loss= loss_fn(outputs, target)optimizer.zero_grad()loss.backward()nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)optimizer.step()

nn.utils.clip_grad_norm_ 的参数:

  • parameters – 一个基于变量的迭代器,会进行梯度归一化

  • max_norm – 梯度的最大范数

  • norm_type – 规定范数的类型,默认为L2

4. 扩展单张图片维度

因为在训练时的数据维度一般都是 (batch_size, c, h, w),而在测试时只输入一张图片,所以需要扩展维度,扩展维度有多个方法:

import cv2import torch
image = cv2.imread(img_path)image = torch.tensor(image)print(image.size())
img = image.view(1, *image.size())print(img.size())
# output:# torch.Size([h, w, c])# torch.Size([1, h, w, c])

import cv2import numpy as np
image = cv2.imread(img_path)print(image.shape)img = image[np.newaxis, :, :, :]print(img.shape)
# output:# (h, w, c)# (1, h, w, c)
import cv2import torch
image = cv2.imread(img_path)image = torch.tensor(image)print(image.size())
img = image.unsqueeze(dim=0)  print(img.size())
img = img.squeeze(dim=0)print(img.size())
# output:# torch.Size([(h, w, c)])# torch.Size([1, h, w, c])# torch.Size([h, w, c])

tensor.unsqueeze(dim):扩展维度,dim指定扩展哪个维度。

tensor.squeeze(dim):去除dim指定的且size为1的维度,维度大于1时,squeeze()不起作用,不指定dim时,去除所有size为1的维度。

5. 独热编码

在PyTorch中使用交叉熵损失函数的时候会自动把label转化成onehot,所以不用手动转化,而使用MSE需要手动转化成onehot编码。

import torchclass_num = 8batch_size = 4
def one_hot(label):    """    将一维列表转换为独热编码    """    label = label.resize_(batch_size, 1)    m_zeros = torch.zeros(batch_size, class_num)    # 从 value 中取值,然后根据 dim 和 index 给相应位置赋值    onehot = m_zeros.scatter_(1, label, 1)  # (dim,index,value)
    return onehot.numpy()  # Tensor -> Numpy
label = torch.LongTensor(batch_size).random_() % class_num  # 对随机数取余print(one_hot(label))
# output:[[0. 0. 0. 1. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0.] [0. 0. 1. 0. 0. 0. 0. 0.] [0. 1. 0. 0. 0. 0. 0. 0.]]

6. 防止验证模型时爆显存

验证模型时不需要求导,即不需要梯度计算,关闭autograd,可以提高速度,节约内存。如果不关闭可能会爆显存。

with torch.no_grad():    # 使用model进行预测的代码    pass

意思就是PyTorch的缓存分配器会事先分配一些固定的显存,即使实际上tensors并没有使用完这些显存,这些显存也不能被其他应用使用。这个分配过程由第一次CUDA内存访问触发的。

而 torch.cuda.empty_cache() 的作用就是释放缓存分配器当前持有的且未占用的缓存显存,以便这些显存可以被其他GPU应用程序中使用,并且通过 nvidia-smi命令可见。注意使用此命令不会释放tensors占用的显存。

对于不用的数据变量,Pytorch 可以自动进行回收从而释放相应的显存。

7. 学习率衰减

import torch.optim as optimfrom torch.optim import lr_scheduler
# 训练前的初始化optimizer = optim.Adam(net.parameters(), lr=0.001)scheduler = lr_scheduler.StepLR(optimizer, 10, 0.1)  # # 每过10个epoch,学习率乘以0.1
# 训练过程中for n in n_epoch:    scheduler.step()    ...

8. 冻结某些层的参数

参考:Pytorch 冻结预训练模型的某一层
https://www.zhihu.com/question/311095447/answer/589307812

在加载预训练模型的时候,我们有时想冻结前面几层,使其参数在训练过程中不发生变化。

我们需要先知道每一层的名字,通过如下代码打印:

net = Network()  # 获取自定义网络结构for name, value in net.named_parameters():    print('name: {0},\t grad: {1}'.format(name, value.requires_grad))
name: cnn.VGG_16.convolution1_1.weight,   grad: Truename: cnn.VGG_16.convolution1_1.bias,   grad: Truename: cnn.VGG_16.convolution1_2.weight,   grad: Truename: cnn.VGG_16.convolution1_2.bias,   grad: Truename: cnn.VGG_16.convolution2_1.weight,   grad: Truename: cnn.VGG_16.convolution2_1.bias,   grad: Truename: cnn.VGG_16.convolution2_2.weight,   grad: Truename: cnn.VGG_16.convolution2_2.bias,   grad: True

后面的True表示该层的参数可训练,然后我们定义一个要冻结的层的列表:

no_grad = [    'cnn.VGG_16.convolution1_1.weight',    'cnn.VGG_16.convolution1_1.bias',    'cnn.VGG_16.convolution1_2.weight',    'cnn.VGG_16.convolution1_2.bias']
net = Net.CTPN()  # 获取网络结构for name, value in net.named_parameters():    if name in no_grad:        value.requires_grad = False    else:        value.requires_grad = True
name: cnn.VGG_16.convolution1_1.weight,   grad: Falsename: cnn.VGG_16.convolution1_1.bias,   grad: Falsename: cnn.VGG_16.convolution1_2.weight,   grad: Falsename: cnn.VGG_16.convolution1_2.bias,   grad: Falsename: cnn.VGG_16.convolution2_1.weight,   grad: Truename: cnn.VGG_16.convolution2_1.bias,   grad: Truename: cnn.VGG_16.convolution2_2.weight,   grad: Truename: cnn.VGG_16.convolution2_2.bias,   grad: True

可以看到前两层的weight和bias的requires_grad都为False,表示它们不可训练。

最后在定义优化器时,只对requires_grad为True的层的参数进行更新。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.01)

9. 对不同层使用不同学习率

net = Network()  # 获取自定义网络结构for name, value in net.named_parameters():    print('name: {}'.format(name))
# 输出:# name: cnn.VGG_16.convolution1_1.weight# name: cnn.VGG_16.convolution1_1.bias# name: cnn.VGG_16.convolution1_2.weight# name: cnn.VGG_16.convolution1_2.bias# name: cnn.VGG_16.convolution2_1.weight# name: cnn.VGG_16.convolution2_1.bias# name: cnn.VGG_16.convolution2_2.weight# name: cnn.VGG_16.convolution2_2.bias

对 convolution1 和 convolution2 设置不同的学习率,首先将它们分开,即放到不同的列表里:

conv1_params = []conv2_params = []
for name, parms in net.named_parameters():    if "convolution1" in name:        conv1_params += [parms]    else:        conv2_params += [parms]
# 然后在优化器中进行如下操作:optimizer = optim.Adam(    [        {"params": conv1_params, 'lr': 0.01},        {"params": conv2_params, 'lr': 0.001},    ],    weight_decay=1e-3,)

我们将模型划分为两部分,存放到一个列表里,每部分就对应上面的一个字典,在字典里设置不同的学习率。当这两部分有相同的其他参数时,就将该参数放到列表外面作为全局参数,如上面的`weight_decay`。

也可以在列表外设置一个全局学习率,当各部分字典里设置了局部学习率时,就使用该学习率,否则就使用列表外的全局学习率。

相关文章:

pytorch使用技巧

pytorch使用技巧 1. 指定GPU编号 设置当前使用的GPU设备仅为0号设备,设备名称为 /gpu:0os.environ["CUDA_VISIBLE_DEVICES"] "0" 设置当前使用的GPU设备为0, 1号两个设备,名称依次为 /gpu:0、/gpu:1: os.environ[&quo…...

从用户数据到区块链:Facebook如何利用去中心化技术

在数字化时代,用户数据的管理和保护已成为科技公司面临的重大挑战。作为全球最大的社交网络平台之一,Facebook不仅在用户数据的处理上积累了丰富的经验,也在探索如何利用去中心化技术,如区块链,来改进其数据管理和用户…...

Elasticsearch之bool查询

bool 查询是 Elasticsearch 中最常用的复合查询类型,允许将多个查询组合在一起。它通过逻辑操作符(如 must、should、must_not 和 filter)来构建复杂的查询条件,从而满足多条件匹配、逻辑与(AND)、或&#…...

IntelliJ IDEA 创建 Java 项目指南

IntelliJ IDEA 是一款功能强大的集成开发环境(IDE),广泛用于 Java 开发。本文将介绍如何在 IntelliJ IDEA 中创建一个新的 Java 项目,包括环境的设置和基本配置。更多问题,请查阅 一、安装 IntelliJ IDEA 1. 下载 In…...

一起学Java(13)-[日志篇]教你分析SLF4J和Log4j2源码,掌握SLF4J与Log4j2桥接集成原理

研究完SLF4J和Logback这种无缝集成的方式(一起学Java(12)-[日志篇]教你分析SLF4J源码,掌握SLF4J如何与Logback无缝集成的原理),继续研究Log4j2和SLF4J这种需要桥接集成的方式。 一、桥接包如何与SLF4J集成 我们已经知道SLF4J利用ServiceLoader机制&…...

深入Redis:核心的缓存

Redis最主要的用途,主要有三个方面:存储数据、缓存、消息队列。 其中,缓存是Redis最常用的场景。Redis使用内存作为硬盘的缓存。把用户集中访问的20%数据放到缓存中去,可以应对80%的请求。 数据库是非常重要的组件,但…...

集群聊天服务器项目【C++】项目介绍和环境搭建

前言:学习一个基于C集群聊天服务器的项目,记录学习的内容和学习的过程。 1.项目介绍 在 Linux 环境下基于 muduo 开发的集群聊天服务器。实现新用户注册、用户登录、添加好友、添加群组、好友通信、群组聊天、保持离线消息等功能。 2.技术栈 Json序列…...

c++ #include <memory> 智能指针介绍

#include <memory> 是 C 标准库中的头文件&#xff0c;用于支持智能指针的功能。智能指针是现代 C 的一种资源管理工具&#xff0c;用于自动管理动态分配的内存&#xff0c;从而减少内存泄漏和悬挂指针等问题的发生。它提供了多种类型的智能指针&#xff0c;包括 std::un…...

32.递归、搜索、回溯之floodfill算法

0.简介 1.图像渲染 . - 力扣&#xff08;LeetCode&#xff09; 题目解析 算法原理 代码 class Solution {int[] dx { 0, 0, 1, -1 };int[] dy { 1, -1, 0, 0 };int m, n;int prev;public int[][] floodFill(int[][] image, int sr, int sc, int color) {if (image[sr][sc]…...

Vue3.5+ 响应式 Props 解构

你好同学&#xff0c;我是沐爸&#xff0c;欢迎点赞、收藏、评论和关注。 在 Vue 3.5 中&#xff0c;响应式 Props 解构已经稳定并默认启用。这意味着在 <script setup> 中从 defineProps 调用解构的变量现在是响应式的。这一改进大大简化了声明带有默认值的 props 的方…...

k8s中的认证授权

目录 一、kubernetes API 访问控制 1.1 UserAccount与ServiceAccount 1.1.1 ServiceAccount 1.1.2 ServiceAccount示例 二、认证(在k8s中建立认证用户) 2.1 创建UserAccount 2.2 RBAC&#xff08;Role Based Access Control&#xff09; 2.2.1 基于角色访问控制授权&…...

Leetcode 3291. Minimum Number of Valid Strings to Form Target I

Leetcode 3291. Minimum Number of Valid Strings to Form Target I 1. 解题思路2. 代码实现 题目链接&#xff1a;3291. Minimum Number of Valid Strings to Form Target I 1. 解题思路 这一题第一反应就是用一个字典树动态规划的方式&#xff0c;倒是也搞定了&#xff0c…...

PostgreSQL的查看主从同步状态

PostgreSQL的查看主从同步状态 PostgreSQL 提供了一些系统视图和函数&#xff0c;查看和监控主从同步的状态。 1 在主节点上查看同步状态 pg_stat_replication 视图 在主节点上&#xff0c;可以通过查询 pg_stat_replication 视图来查看复制的详细状态信息&#xff0c;包括…...

Java多态性的理解

方法的覆盖 子类的方法重写了父类的方法&#xff0c;相当于对原来的方法进行了增强&#xff0c;接口就是这样的思想。 属性的隔离&#xff08;Java中什么情况下都不会属性覆盖&#xff0c;python可能会覆盖&#xff09; public class Main {public static void main(String[…...

安全工具 | 使用Burp Suite的10个小tips

Burp Suite 应用程序中有用功能的集合 img Burp Suite 是一款出色的分析工具&#xff0c;用于测试 Web 应用程序和系统的安全漏洞。它有很多很棒的功能可以在渗透测试中使用。您使用它的次数越多&#xff0c;您就越发现它的便利功能。 本文内容是我在测试期间学到并经常的主要…...

企业项目中字符串工具类

此工具类暂时包含如下功能&#xff1a; isEmpty()判断字符串是否为空subSpecifiedString()判断字符串是否超出指定长度&#xff0c;超出则截取到指定长度yearMonthToDate()将年月的字符串转成年月日格式 yearMonthToDateTime()将年月的字符串转成年月日时分秒格式 package co…...

下载github patch到本地

以下是几种从 GitHub 上下载以.patch 结尾的补丁文件的方法&#xff1a; 通过浏览器直接下载 打开包含该.patch 文件的 GitHub 仓库。在仓库的文件列表中找到对应的.patch 文件。点击该文件&#xff0c;浏览器会显示文件的内容&#xff0c;在页面的右上角通常会有一个“Raw”…...

C++基础部分代码

C OOP面对对象 this指针 C:各种各样的函数定义 struct C&#xff1a;类》实体的抽象类型 实体(属性&#xff0c;行为)-》ADT&#xff08;abstract data type) OOP语言的四大特征是什么&#xff1f; 抽象 封装/隐藏 继承 多态 访问限定符&#xff1a;public公有的 private私有的…...

neo4j(spring) 使用示例

文章目录 前言一、neo4j是什么二、开始编码1. yml 配置2. crud 测试3. node relation 与java中对象的关系4. 编码测试 总结 前言 图数据库先驱者 neo4j&#xff1a;neo4j官网地址 可以选择桌面版安装等多种方式,我这里采用的是docker安装 直接执行docker安装命令: docker run…...

链接升级:Element UI <el-link> 的应用

链接升级&#xff1a;Element UI 的应用 一 . 创建文字链接1.1 注册路由1.2 创建文字链接 二 . 文字链接的属性2.1 文字链接的颜色2.2 是否显示下划线2.3 是否禁用状态2.4 填写跳转地址2.5 加入图标 在本篇文章中&#xff0c;我们将深入探索Element UI中的<el-link>组件—…...

Lombok 的 @Data 注解失效,未生成 getter/setter 方法引发的HTTP 406 错误

HTTP 状态码 406 (Not Acceptable) 和 500 (Internal Server Error) 是两类完全不同的错误&#xff0c;它们的含义、原因和解决方法都有显著区别。以下是详细对比&#xff1a; 1. HTTP 406 (Not Acceptable) 含义&#xff1a; 客户端请求的内容类型与服务器支持的内容类型不匹…...

自然语言处理——Transformer

自然语言处理——Transformer 自注意力机制多头注意力机制Transformer 虽然循环神经网络可以对具有序列特性的数据非常有效&#xff0c;它能挖掘数据中的时序信息以及语义信息&#xff0c;但是它有一个很大的缺陷——很难并行化。 我们可以考虑用CNN来替代RNN&#xff0c;但是…...

网络编程(UDP编程)

思维导图 UDP基础编程&#xff08;单播&#xff09; 1.流程图 服务器&#xff1a;短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...

Springboot社区养老保险系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;社区养老保险系统小程序被用户普遍使用&#xff0c;为方…...

初探Service服务发现机制

1.Service简介 Service是将运行在一组Pod上的应用程序发布为网络服务的抽象方法。 主要功能&#xff1a;服务发现和负载均衡。 Service类型的包括ClusterIP类型、NodePort类型、LoadBalancer类型、ExternalName类型 2.Endpoints简介 Endpoints是一种Kubernetes资源&#xf…...

JS手写代码篇----使用Promise封装AJAX请求

15、使用Promise封装AJAX请求 promise就有reject和resolve了&#xff0c;就不必写成功和失败的回调函数了 const BASEURL ./手写ajax/test.jsonfunction promiseAjax() {return new Promise((resolve, reject) > {const xhr new XMLHttpRequest();xhr.open("get&quo…...

省略号和可变参数模板

本文主要介绍如何展开可变参数的参数包 1.C语言的va_list展开可变参数 #include <iostream> #include <cstdarg>void printNumbers(int count, ...) {// 声明va_list类型的变量va_list args;// 使用va_start将可变参数写入变量argsva_start(args, count);for (in…...

Python的__call__ 方法

在 Python 中&#xff0c;__call__ 是一个特殊的魔术方法&#xff08;magic method&#xff09;&#xff0c;它允许一个类的实例像函数一样被调用。当你在一个对象后面加上 () 并执行时&#xff08;例如 obj()&#xff09;&#xff0c;Python 会自动调用该对象的 __call__ 方法…...

新版NANO下载烧录过程

一、序言 搭建 Jetson 系列产品烧录系统的环境需要在电脑主机上安装 Ubuntu 系统。此处使用 18.04 LTS。 二、环境搭建 1、安装库 $ sudo apt-get install qemu-user-static$ sudo apt-get install python 搭建环境的过程需要这个应用库来将某些 NVIDIA 软件组件安装到 Je…...

【论文解读】MemGPT: 迈向为操作系统的LLM

1st author: Charles Packer paper MemGPT[2310.08560] MemGPT: Towards LLMs as Operating Systems code: letta-ai/letta: Letta (formerly MemGPT) is the stateful agents framework with memory, reasoning, and context management. 这个项目现在已经转化为 Letta &a…...