当前位置: 首页 > news >正文

Pytorch 复习总结 4

Pytorch 复习总结,仅供笔者使用,参考教材:

  • 《动手学深度学习》
  • Stanford University: Practical Machine Learning

本文主要内容为:Pytorch 深度学习计算。

本文先介绍了深度学习中自定义层和块的方法,然后介绍了一些有关参数的方法。


Pytorch 语法汇总:

  • Pytorch 张量的常见运算、线性代数、高等数学、概率论 部分 见 Pytorch 复习总结1;
  • Pytorch 线性神经网络 部分 见 Pytorch 复习总结2;
  • Pytorch 多层感知机 部分 见 Pytorch 复习总结3;
  • Pytorch 深度学习计算 部分 见 Pytorch 复习总结4;
  • Pytorch 卷积神经网络 部分 见 Pytorch 复习总结5;
  • Pytorch 现代卷积神经网络 部分 见 Pytorch 复习总结6;

目录

  • 一. 自定义块
    • 1. 顺序块
    • 2. 自定义前向传播
    • 3. 嵌套块
  • 二. 自定义层
    • 1. 无参数层
    • 2. 有参数层
  • 三. 参数管理
    • 1. 参数访问
    • 2. 参数初始化
    • 3. 延后初始化
  • 四. 文件读写
    • 1. 加载和保存张量
    • 2. 加载和保存模型参数

层是神经网络的基本组成单元,如全连接层、卷积层、池化层等。块是由层组成的更大的功能单元,用于构建复杂的神经网络结构。块可以是一系列相互关联的层,形成一个功能完整的单元,也可以是一组层的重复模式,用于实现重复的结构。下图就是多个层组合成块形成的更大模型:
在这里插入图片描述

在实际应用中,经常会需要自定义层和块。

一. 自定义块

1. 顺序块

nn.Sequential 本质上就是一个顺序块,通过在块中实例化层来创建神经网络。 nn.Module 是 PyTorch 中用于构建神经网络模型的基类,nn.Sequential 和各种层都是继承自 Module,nn.Sequential 维护一个由多个层组成的有序列表,列表中的每个层连接在一起,将每个层的输出作为下一个层的输入。

如果想要自定义一个顺序块,必须要定义以下两个关键函数:

  1. 构造函数:将每个层按顺序逐个加入列表;
  2. 前向传播函数:将每一层按顺序传递给下一层;
import torch
from torch import nnclass MySequential(nn.Module):def __init__(self, *args):super().__init__()for idx, module in enumerate(args):self._modules[str(idx)] = moduledef forward(self, X):# self._modules的类型是OrderedDictfor block in self._modules.values():X = block(X)return Xnet = MySequential(nn.Linear(20, 256),nn.ReLU(),nn.Linear(256, 10)
)X = torch.rand(2, 20)
output = net(X)

上述示例代码中,定义 net 时会自动调用 __init__(self, *args) 函数,实例化 MySequential 对象;调用 net(X) 相当于 net.__call__(X),会自动调用模型类中定义的 forward() 函数,进行前向传播,每一层的传播本质上就是调用 block(X) 的过程。

2. 自定义前向传播

nn.Sequential 类将前向传播过程封装成函数,用户可以自由使用但没法修改传播细节。如果想要自定义前向传播过程中的细节,就需要自定义顺序块及 forward 函数,而不能仅仅依赖预定义的框架。

例如,需要一个计算函数 f ( x , w ) = c ⋅ w T x f(\bold x,\bold w)=c \cdot \bold w ^T \bold x f(x,w)=cwTx 的层,并且在传播过程中引入控制流。其中 x \bold x x 是输入, w \bold w w 是参数, c c c 是优化过程中不需要更新的指定常量。为此,定义 FixedHiddenMLP 类如下:

import torch
from torch import nn
from torch.nn import functional as Fclass FixedHiddenMLP(nn.Module):def __init__(self):super().__init__()self.rand_weight = torch.rand((20, 20), requires_grad=False)    # 优化过程中不需要更新的指定常量self.linear = nn.Linear(20, 20)def forward(self, X):X = self.linear(X)X = F.relu(torch.mm(X, self.rand_weight) + 1)X = self.linear(X)          # 两个全连接层共享参数while X.abs().sum() > 1:    # 控制流X /= 2return X

3. 嵌套块

多个层可以组合成块,多个块还可以嵌套形成更大的模型:

import torch
from torch import nn
from torch.nn import functional as Fclass FixedHiddenMLP(nn.Module):def __init__(self):super().__init__()self.rand_weight = torch.rand((20, 20), requires_grad=False)    # 优化过程中不需要更新的指定常量self.linear = nn.Linear(20, 20)def forward(self, X):X = self.linear(X)X = F.relu(torch.mm(X, self.rand_weight) + 1)X = self.linear(X)          # 两个全连接层共享参数while X.abs().sum() > 1:    # 控制流X /= 2return X.sum()class NestMLP(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),nn.Linear(64, 32), nn.ReLU())self.linear = nn.Linear(32, 16)def forward(self, X):return self.linear(self.net(X))net = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP()
)X = torch.rand(2, 20)
output = net(X)

二. 自定义层

和自定义块一样,自定义层也需要实现构造函数和前向传播函数。

1. 无参数层

import torch
from torch import nnclass CenteredLayer(nn.Module):def __init__(self):super().__init__()def forward(self, X):return X - X.mean()net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
X = torch.rand(4, 8)
output = net(X)
print(output.mean())	# tensor(0., grad_fn=<MeanBackward0>)

2. 有参数层

import torch
from torch import nn
import torch.nn.functional as Fclass MyLinear(nn.Module):def __init__(self, in_units, out_units):super().__init__()self.weight = nn.Parameter(torch.randn(in_units, out_units))self.bias = nn.Parameter(torch.randn(out_units,))def forward(self, X):linear = torch.matmul(X, self.weight.data) + self.bias.datareturn F.relu(linear)net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1)
)
X = torch.rand(2, 64)
output = net(X)
print(output)       # tensor([[11.9497], [13.9729]])

三. 参数管理

在实验过程中,有时需要提取参数,以便检查或在其他环境中复用。本节将介绍参数的访问方法和参数的初始化。

1. 参数访问

  • net.state_dict() / net[i].state_dict():返回模型或某一层参数的状态字典;
  • net[i].weight.data / net[i].bias.data:返回某一层的权重 / 偏置参数;
  • net[i].weight.grad:返回某一层的权重参数的梯度属性。只有调用了 backward() 方法后才能访问到梯度值,否则为 None;
import torch
from torch import nnnet = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
output = net(X)print(net.state_dict())
'''
OrderedDict([('0.weight', tensor([[ 0.2178, -0.3286,  0.4875, -0.0347],[-0.0415,  0.0009, -0.2038, -0.1813],[-0.2766, -0.4759, -0.3134, -0.2782],[ 0.4854,  0.0606,  0.1070,  0.0650],[-0.3908,  0.2412, -0.1348,  0.3921],[-0.3044, -0.0331, -0.1213, -0.1690],[-0.3875, -0.0117,  0.3195, -0.1748],[ 0.1840, -0.3502,  0.4253,  0.2789]])), ('0.bias', tensor([-0.2327, -0.0745,  0.4923, -0.1018,  0.0685,  0.4423, -0.2979,  0.1109])), ('2.weight', tensor([[ 0.1006,  0.2959, -0.1316, -0.2015,  0.2446, -0.0158,  0.2217, -0.2780]])), ('2.bias', tensor([0.2362]))])
'''
print(net[2].state_dict())
'''
OrderedDict([('weight', tensor([[ 0.1006,  0.2959, -0.1316, -0.2015,  0.2446, -0.0158,  0.2217, -0.2780]])), ('bias', tensor([0.2362]))])
'''
print(net[2].bias)
'''
Parameter containing:
tensor([0.2362], requires_grad=True)
'''
print(net[2].bias.data)
'''
tensor([0.2362])
'''

如果想一次性访问所有参数,可以使用 for 循环递归遍历:

import torch
from torch import nnnet = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
output = net(X)print(*[(name, param.shape) for name, param in net[0].named_parameters()])
'''
('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))
'''
print(*[(name, param.shape) for name, param in net.named_parameters()])
'''
('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))
'''

如果网络是由多个块相互嵌套的,可以按块索引后再访问参数:

import torch
from torch import nndef block1():return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),nn.Linear(8, 4), nn.ReLU())def block2():net = nn.Sequential()for i in range(4):net.add_module(f'block {i}', block1())return netnet = nn.Sequential(block2(), nn.Linear(4, 1))
X = torch.rand(size=(2, 4))
output = net(X)print(net)
'''
Sequential((0): Sequential((block 0): Sequential((0): Linear(in_features=4, out_features=8, bias=True)(1): ReLU()(2): Linear(in_features=8, out_features=4, bias=True)(3): ReLU())(block 1): Sequential((0): Linear(in_features=4, out_features=8, bias=True)(1): ReLU()(2): Linear(in_features=8, out_features=4, bias=True)(3): ReLU())(block 2): Sequential((0): Linear(in_features=4, out_features=8, bias=True)(1): ReLU()(2): Linear(in_features=8, out_features=4, bias=True)(3): ReLU())(block 3): Sequential((0): Linear(in_features=4, out_features=8, bias=True)(1): ReLU()(2): Linear(in_features=8, out_features=4, bias=True)(3): ReLU()))(1): Linear(in_features=4, out_features=1, bias=True)
)
'''
print(net[0][1][0].bias.data)
'''
tensor([-0.0083,  0.2490,  0.1794,  0.1927,  0.1797,  0.1156,  0.4409,  0.1320])
'''

2. 参数初始化

PyTorch 的 nn.init 模块提供了多种初始化方法:

  • nn.init.constant_(layer.weight, c):将权重参数初始化为指定的常量值;
  • nn.init.zeros_(layer.weight):将权重参数初始化为 0;
  • nn.init.ones_(layer.weight):将权重参数初始化为 1;
  • nn.init.uniform_(layer.weight, a, b):将权重参数按均匀分布初始化;
  • nn.init.xavier_uniform_(layer.weight)
  • nn.init.normal_(layer.weight, mean, std):将权重参数按正态分布初始化;
  • nn.init.orthogonal_(layer.weight):将权重参数初始化为正交矩阵;
  • nn.init.sparse_(layer.weight, sparsity, std):将权重参数初始化为稀疏矩阵;

初始化时,可以直接 net.apply(init_method) 初始化整个网络,也可以 net[i].apply(init_method) 初始化某一层:

import torch
from torch import nndef init_normal(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)nn.init.zeros_(m.bias)def init_constant(m):if type(m) == nn.Linear:nn.init.constant_(m.weight, 1)nn.init.zeros_(m.bias)net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
output = net(X)# net.apply(init_normal)
net[0].apply(init_normal)
net[2].apply(init_constant)

3. 延后初始化

有些情况下,无法提前判断网络的输入维度。为了代码能够继续运行,需要使用延后初始化,即直到数据第一次通过模型传递时,框架才会动态地推断出每个层的大小。由于 PyTorch 的延后初始化功能还处于开发阶段,API 和功能随时可能变化,下面只给出简单示例:

import torch
from torch import nnnet = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))print(net)
'''
Sequential((0): LazyLinear(in_features=0, out_features=256, bias=True)(1): ReLU()(2): LazyLinear(in_features=0, out_features=10, bias=True)
)
'''X = torch.rand(2, 20)
net(X)
print(net)
'''
Sequential((0): Linear(in_features=20, out_features=256, bias=True)(1): ReLU()(2): Linear(in_features=256, out_features=10, bias=True)
)
'''

四. 文件读写

可以使用 loadsave 函数读写张量和模型参数。

1. 加载和保存张量

2. 加载和保存模型参数

相关文章:

Pytorch 复习总结 4

Pytorch 复习总结&#xff0c;仅供笔者使用&#xff0c;参考教材&#xff1a; 《动手学深度学习》Stanford University: Practical Machine Learning 本文主要内容为&#xff1a;Pytorch 深度学习计算。 本文先介绍了深度学习中自定义层和块的方法&#xff0c;然后介绍了一些…...

YOLOv9中加入SCConv模块!

专栏介绍&#xff1a;YOLOv9改进系列 | 包含深度学习最新创新&#xff0c;主力高效涨点&#xff01;&#xff01;&#xff01; 一、本文介绍 本文将一步步演示如何在YOLOv9中添加 / 替换新模块&#xff0c;寻找模型上的创新&#xff01; 适用检测目标&#xff1a; YOLOv9模块…...

代码随想录算法训练营第四十七天丨198. 打家劫舍、​ 213. 打家劫舍 II​、337. 打家劫舍 III

198. 打家劫舍 自己的思路&#xff1a; 初始化两个dp数组&#xff0c;dp[i][0]表示不偷第i户&#xff0c;在0-i户可以偷到的最大金额&#xff0c;dp[i][1]表示偷i户在0-i户可以偷到的最大金额。 class Solution:def rob(self, nums: List[int]) -> int:n len(nums)dp […...

龙蜥Anolis 8.4 anck 安装mysql5.7

el8没有用mysql5.7了&#xff0c;镜像里是mysql8。 禁用 sudo dnf remove mysql sudo dnf module reset mysql sudo dnf module disable mysql 修改Yum源 sudo vi /etc/yum.repos.d/mysql-community.repo [mysql57-community] nameMySQL 5.7 Community Server baseurlhttp:…...

【踩坑】修复xrdp无法关闭Authentication Required验证窗口

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhang.cn] 问题如下&#xff0c;时不时出现&#xff0c;有时还怎么都关不掉&#xff0c;很烦&#xff1a; 解决方法一&#xff1a;命令行输入 dbus-send --typemethod_call --destorg.gnome.Shell /org/gnome/Shell org.gn…...

python学习笔记 - 标准库常量

Python 中有一些内置的常量&#xff0c;它们是一些特殊的值&#xff0c;通常不会改变。以下是其中一些常见的内置常量及其详细解释以及使用示例&#xff1a; True&#xff1a; 表示布尔值真。给 True 赋值是非法的并会引发 SyntaxError。 x True print(x) # 输出&#xff1a…...

视频和音频使用ffmpeg进行合并和分离(MP4)

1.下载ffmpeg 官网地址&#xff1a;https://ffmpeg.org/download.html 2.配置环境变量 此电脑右键点击 属性 - 高级系统配置 -高级 -环境变量 - 系统变量 path 新增 文件的bin路径 3.验证配置成功 ffmpeg -version 返回版本信息说明配置成功4.执行合并 ffmpeg -i 武家坡20…...

02| JVM堆中垃圾回收的大致过程

如果一直在创建对象&#xff0c;堆中年轻代中Eden区会逐渐放满&#xff0c;如果Eden放满&#xff0c;会触发minor GC回收&#xff0c;创建对象的时GC Roots&#xff0c;如果存在于里面的对象&#xff0c;则被视为非垃圾对象&#xff0c;不会被此次gc回收&#xff0c;就会被移入…...

R语言数据可视化之美专业图表绘制指南(增强版):第1章 R语言编程与绘图基础

第1章 R语言编程与绘图基础 目录 第1章 R语言编程与绘图基础前言1.1 学术图表的基本概念1.1.1 学术图表的基本作用1.1.2基本类别1.1.3 学术图表的绘制原则 1.2 你为什么要选择R1.3 安装 前言 这是我第一次在博客里展示学习中国作者的教材的笔记。我选择这本书的依据是作者同时…...

网站添加pwa操作和配置manifest.json后,没有效果排查问题

pwa技术官网&#xff1a;https://web.dev/learn/pwa 应用清单manifest.json文件字段说明&#xff1a;https://web.dev/articles/add-manifest?hlzh-cn Web App Manifest&#xff1a;Web App Manifest | MDN 当网站添加了manifest.json文件后&#xff0c;也引入到html中了&a…...

MongoDB聚合运算符:$cosh

文章目录 语法使用举例双曲余弦值角度双曲余弦值弧度 $cosh聚合运算符用来计算双曲余弦值&#xff0c;返回指定表达式的双曲余弦值。 语法 { $cosh: <expression> }<expression>为可被解析为数值的表达式$cosh返回弧度&#xff0c;使用$radiansToDegrees运算符可…...

Jenkins配置在远程服务器上执行shell脚本(两种方式)

Jenkins配置在远程服务器上执行shell脚本 方式一&#xff1a;通过SSH免密方式执行 说明&#xff1a;Jenkins部署在ServerA&#xff1a;10.1.1.74上&#xff0c;要运行的程序在ServerB&#xff1a;10.1.1.196 分两步 第一步&#xff1a;Linux Centos7配置SSH免密登录 Linux…...

Java+SpringBoot,打造社区疫情信息新生态

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…...

js ES6判断字符串是否以某个字符串开头或者结尾startsWith、endsWith

1.前言 startsWith&#xff1a;startsWith方法用于检查字符串是否以指定的字符串开头。 endsWith&#xff1a;endsWith方法用于检查字符串是否以指定的字符串结尾。 2.用法示例 const str Hello, world!;console.log(str.startsWith(Hello)); // true console.log(str.starts…...

预研项目完成后小批量验证(技术变更流程)

...

Bert-as-service 实战

参考&#xff1a;bert-as-service 详细使用指南写给初学者-CSDN博客 GitHub - ymcui/Chinese-BERT-wwm: Pre-Training with Whole Word Masking for Chinese BERT&#xff08;中文BERT-wwm系列模型&#xff09; 下载&#xff1a;https://storage.googleapis.com/bert_models/…...

微信小程序(四十七)多个token存储

注释很详细&#xff0c;直接上代码 新增内容&#xff1a; 1.基础存储模板 2.中括号实现变量名匹配 源码&#xff1a; app.js App({//提前声明的变量名token:wx.getStorageSync(toke),refreshToken:wx.getSystemInfoAsync(refreshToken),setToken(key,token){//保存token到全局…...

机器学习(II)--样本不平衡

现实中&#xff0c;样本&#xff08;类别&#xff09;样本不平衡&#xff08;class-imbalance&#xff09;是一种常见的现象&#xff0c;如&#xff1a;金融欺诈交易检测&#xff0c;欺诈交易的订单样本通常是占总交易数量的极少部分&#xff0c;而且对于有些任务而言少数样本更…...

几个好用的 VUE Table

Vue easytable - 功能恰到好处 无学习成本 上手就用Vue good table - UI 清新 功能直给 适合小项目Vxe table - 宝藏级 table 组件 高级功能低调好用 维护频率高tabulator - 元老级 table 组件 高级功能平民化AG Grid - 媲美 Excel 的 Table 组件 能想到的复杂功能它都能做到...

Vue源码系列讲解——实例方法篇【三】(生命周期相关方法)

目录 0. 前言 1. vm.$mount 1.1 用法回顾 1.2 内部原理 2. vm.$forceUpdate 2.1 用法回顾 2.2 内部原理 3. vm.$nextTick 3.1 用法回顾 3.2 JS的运行机制 3.3 内部原理 能力检测 执行回调队列 4. vm.$destory 4.1 用法回顾 4.2 内部原理 0. 前言 与生命周期相关…...

vscode里如何用git

打开vs终端执行如下&#xff1a; 1 初始化 Git 仓库&#xff08;如果尚未初始化&#xff09; git init 2 添加文件到 Git 仓库 git add . 3 使用 git commit 命令来提交你的更改。确保在提交时加上一个有用的消息。 git commit -m "备注信息" 4 …...

JVM垃圾回收机制全解析

Java虚拟机&#xff08;JVM&#xff09;中的垃圾收集器&#xff08;Garbage Collector&#xff0c;简称GC&#xff09;是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象&#xff0c;从而释放内存空间&#xff0c;避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...

VTK如何让部分单位不可见

最近遇到一个需求&#xff0c;需要让一个vtkDataSet中的部分单元不可见&#xff0c;查阅了一些资料大概有以下几种方式 1.通过颜色映射表来进行&#xff0c;是最正规的做法 vtkNew<vtkLookupTable> lut; //值为0不显示&#xff0c;主要是最后一个参数&#xff0c;透明度…...

【Java_EE】Spring MVC

目录 Spring Web MVC ​编辑注解 RestController RequestMapping RequestParam RequestParam RequestBody PathVariable RequestPart 参数传递 注意事项 ​编辑参数重命名 RequestParam ​编辑​编辑传递集合 RequestParam 传递JSON数据 ​编辑RequestBody ​…...

精益数据分析(97/126):邮件营销与用户参与度的关键指标优化指南

精益数据分析&#xff08;97/126&#xff09;&#xff1a;邮件营销与用户参与度的关键指标优化指南 在数字化营销时代&#xff0c;邮件列表效度、用户参与度和网站性能等指标往往决定着创业公司的增长成败。今天&#xff0c;我们将深入解析邮件打开率、网站可用性、页面参与时…...

零基础在实践中学习网络安全-皮卡丘靶场(第九期-Unsafe Fileupload模块)(yakit方式)

本期内容并不是很难&#xff0c;相信大家会学的很愉快&#xff0c;当然对于有后端基础的朋友来说&#xff0c;本期内容更加容易了解&#xff0c;当然没有基础的也别担心&#xff0c;本期内容会详细解释有关内容 本期用到的软件&#xff1a;yakit&#xff08;因为经过之前好多期…...

在Ubuntu24上采用Wine打开SourceInsight

1. 安装wine sudo apt install wine 2. 安装32位库支持,SourceInsight是32位程序 sudo dpkg --add-architecture i386 sudo apt update sudo apt install wine32:i386 3. 验证安装 wine --version 4. 安装必要的字体和库(解决显示问题) sudo apt install fonts-wqy…...

基于TurtleBot3在Gazebo地图实现机器人远程控制

1. TurtleBot3环境配置 # 下载TurtleBot3核心包 mkdir -p ~/catkin_ws/src cd ~/catkin_ws/src git clone -b noetic-devel https://github.com/ROBOTIS-GIT/turtlebot3.git git clone -b noetic https://github.com/ROBOTIS-GIT/turtlebot3_msgs.git git clone -b noetic-dev…...

android RelativeLayout布局

<?xml version"1.0" encoding"utf-8"?> <RelativeLayout xmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height"match_parent"android:gravity&…...

mac:大模型系列测试

0 MAC 前几天经过学生优惠以及国补17K入手了mac studio,然后这两天亲自测试其模型行运用能力如何&#xff0c;是否支持微调、推理速度等能力。下面进入正文。 1 mac 与 unsloth 按照下面的进行安装以及测试&#xff0c;是可以跑通文章里面的代码。训练速度也是很快的。 注意…...