Pytorch 复习总结 3
Pytorch 复习总结,仅供笔者使用,参考教材:
- 《动手学深度学习》
- Stanford University: Practical Machine Learning
本文主要内容为:Pytorch 多层感知机。
本文先介绍了多层感知机的用法,再就训练过程中经常出现的过拟合现象提出解决办法。
Pytorch 语法汇总:
- Pytorch 张量的常见运算、线性代数、高等数学、概率论 部分 见 Pytorch 复习总结1;
- Pytorch 线性神经网络 部分 见 Pytorch 复习总结2;
- Pytorch 多层感知机 部分 见 Pytorch 复习总结3;
- Pytorch 深度学习计算 部分 见 Pytorch 复习总结4;
- Pytorch 卷积神经网络 部分 见 Pytorch 复习总结5;
- Pytorch 现代卷积神经网络 部分 见 Pytorch 复习总结6;
目录
- 一. 多层感知机
- 1. 读取数据集
- 2. 神经网络模型
- 3. 激活函数
- 4. 损失函数
- 5. 优化器
- 6. 训练
- 二. 过拟合的缓解
- 1. 权重衰减
- 2. Dropout
一. 多层感知机
虽然线性模型易于实现和理解、计算成本低、泛化能力强,但是对于一些非线性问题,可能会违反线性模型的单调性。为此,多层感知器引入了隐藏层来克服线性模型的限制,并且加入激活函数以增强网络非线性建模能力。
1. 读取数据集
同 Pytorch 复习总结 2 中 Softmax 回归的数据读取,继续使用 Fashion-MNIST 图像分类数据集:
import torch
import torchvision
from torch.utils import data
from torchvision import transformsdef load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集并将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True),data.DataLoader(mnist_test, batch_size, shuffle=False))batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)
2. 神经网络模型
先将输入的图像展平,然后使用 2 个全连接层进行处理,中间的全连接层需要使用激活函数激活,最后一层全连接层作为输出:
from torch import nn
net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 10)
)
仍然使用 init_weights()
函数按正态分布初始化所有全连接层的权重:
def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)
3. 激活函数
上一节使用了 ReLU 函数进行激活,在实际应用中,还可以使用 sigmoid、tanh 等函数激活。ReLU、sigmoid、tanh 函数的梯度可视化如下:
import torch
from matplotlib import pyplot as pltx = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)
# y = torch.relu(x)
# y = torch.sigmoid(x)
y = torch.tanh(x)
y.backward(torch.ones_like(x), retain_graph=True)
plt.figure(figsize=(5, 2.5))
plt.plot(x.detach(), x.grad)
plt.show()
4. 损失函数
同 Softmax 回归:
loss = nn.CrossEntropyLoss(reduction='none')
5. 优化器
同 Softmax 回归:
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
6. 训练
同 Softmax 回归,可以将训练过程封装成函数:
def accuracy(y_hat, y):"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def train_net(net, train_iter, test_iter, loss, num_epochs, trainer):for epoch in range(num_epochs): # 迭代训练轮次net.train() # 将模型设置为训练模式train_loss_sum = 0.0 # 训练损失总和train_acc_sum = 0.0 # 训练准确度总和sample_num = 0 # 样本数for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y)trainer.zero_grad()l.mean().backward()trainer.step()train_loss_sum += l.sum()train_acc_sum += accuracy(y_hat, y)sample_num += y.numel()train_loss = train_loss_sum / sample_numtrain_acc = train_acc_sum / sample_numnet.eval() # 将模型设置为评估模式test_acc_sum = 0.0test_sample_num = 0for X, y in test_iter:test_acc_sum += accuracy(net(X), y)test_sample_num += y.numel()test_acc = test_acc_sum / test_sample_numprint(f'epoch {epoch + 1}, 'f'train loss {train_loss:.4f}, train acc {train_acc:.4f}, 'f'test acc {test_acc:.4f}')num_epochs = 10
train_net(net, train_iter, test_iter, loss, num_epochs, trainer)
二. 过拟合的缓解
当模型过于复杂、训练数据太少、迭代轮数太多时,就会出现过拟合现象。解决过拟合的方法有很多:
- 增加数据量:增加训练数据可以帮助模型更好地学习数据的真实规律,减少过拟合的发生;
- 简化模型:降低模型的复杂度,可以通过减少模型的参数数量、使用正则化等方法来实现;
- 交叉验证:使用交叉验证来评估模型的泛化能力,选择最优的模型;
- 提前停止:即 Dropout,在训练过程中监控模型在验证集上的表现,当验证集误差不再下降甚至开始上升时,及时停止训练,防止模型过拟合;
- 集成学习:使用集成学习方法(如随机森林、梯度提升树等)降低模型的方差,提高泛化能力。
下面介绍几种常用的正则化方法。
1. 权重衰减
权重衰减 (Weight Decay) 通过向损失函数中添加一个惩罚项来减小模型复杂度,以防止过拟合。惩罚项也叫 正则项,通常是权重的平方和(即 L2 范数)或权重的绝对值和(即 L1 范数)乘以一个正则化系数。
以线性回归的损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 为例,使用优化器训练时,在损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 上添加 L2 范数如下:
L ( w , b ) + λ 2 ∥ w ∥ 2 = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 + λ 2 ∥ w ∥ 2 L(\mathbf{w}, b)+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ =\frac{1}{n} \sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right)^2+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ L(w,b)+2λ∥w∥2=n1i=1∑n21(w⊤x(i)+b−y(i))2+2λ∥w∥2
损失函数中没有添加偏置 b b b 的惩罚项,因为一般情况下,网络输出层的偏置项不需要正则化。代入 w \mathbf{w} w 的参数更新表达式为:
w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) \mathbf{w} \leftarrow(1-\eta \lambda) \mathbf{w}-\frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right) w←(1−ηλ)w−∣B∣ηi∈B∑x(i)(w⊤x(i)+b−y(i))
要想对模型进行权重衰减,只需要在实例化优化器时通过 weight_decay
指定权重衰减参数。默认情况下,PyTorch 同时衰减权重和偏移:
trainer = torch.optim.SGD(net.parameters(), lr=lr)
如果想要只衰减权重,需要指定参数:
params_to_optimize = [{"params": net[0].weight, 'weight_decay': wd},{"params":net[0].bias}
]
trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)
2. Dropout
Dropout 通过在训练过程中随机地将网络 内部 的一部分神经元的输出设置为零,即以一定的概率 “丢弃” 这些神经元。这样可以防止神经元在训练过程中过于依赖其他神经元,从而降低了网络对特定神经元的依赖性,使得网络更具鲁棒性:
通常情况下,Dropout 只在训练过程中使用,不在推理阶段使用,因为推理时模型需要产生确定性的输出。
Dropout 需要在网络中添加 Dropout 层,一般位于激活函数后,并且给定 dropout 概率:
dropout1, dropout2 = 0.2, 0.5net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Dropout(dropout1),nn.Linear(256, 256),nn.ReLU(),nn.Dropout(dropout2),nn.Linear(256, 10)
)def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)
Dropout 概率的设置技巧是靠近输入层的地方设置较低的概率,远离输入层的地方设置较高的概率。
相关文章:

Pytorch 复习总结 3
Pytorch 复习总结,仅供笔者使用,参考教材: 《动手学深度学习》Stanford University: Practical Machine Learning 本文主要内容为:Pytorch 多层感知机。 本文先介绍了多层感知机的用法,再就训练过程中经常出现的过拟…...

2024年危险化学品经营单位主要负责人证考试题库及危险化学品经营单位主要负责人试题解析
题库来源:安全生产模拟考试一点通公众号小程序 2024年危险化学品经营单位主要负责人证考试题库及危险化学品经营单位主要负责人试题解析是安全生产模拟考试一点通结合(安监局)特种作业人员操作证考试大纲和(质检局)特…...

go使用trpc案例
1.go下载trpc go install trpc.group/trpc-go/trpc-cmdline/trpclatest 有报错的话尝试配置一些代理(选一个) go env -w GOPROXYhttps://goproxy.cn,direct go env -w GOPROXYhttps://goproxy.io,direct go env -w GOPROXYhttps://goproxy.baidu.com/…...

nodejs+vue+ElementUi废品废弃资源回收系统
系统主要是以后台管理员管理为主。管理员需要先登录系统然后才可以使用本系统,管理员可以对系统用户管理、用户信息管理、回收站点管理、站点分类管理、站点分类管理、留言板管理、系统管理进行添加、查询、修改、删除,以保障废弃资源回收系统系统的正常…...

【Java程序设计】【C00277】基于Springboot的招生管理系统(有论文)
基于Springboot的招生管理系统(有论文) 项目简介项目获取开发环境项目技术运行截图 项目简介 这是一个基于Springboot的招生管理系统 本系统分为系统功能模块、管理员功能模块以及学生功能模块。 系统功能模块:在系统首页可以查看首页、专业…...

汇编语言与接口技术实践——秒表
1. 设计要求 基于 51 开发板,利用键盘作为按键输入,将数码管作为显示输出,实现电子秒表。 功能要求: (1)计时精度达到百分之一秒; (2)能按键记录下5次时间并通过按键回看 (3)设置时间,实现倒计时,时间到,数码管闪烁 10 次,并激发蜂鸣器,可通过按键解除。 2. 设计思…...

【数据结构与算法】(19)高级数据结构与算法设计之 图 拓扑排序 最短路径 最小生成树 不相交集合(并查集合)代码示例
目录 6) 拓扑排序KahnDFS 7) 最短路径DijkstraBellman-FordFloyd-Warshall 8) 最小生成树PrimKruskal 9) 不相交集合(并查集合)基础路径压缩Union By Size 图-相关题目 6) 拓扑排序 #mermaid-svg-MQhLsXiMwnlUL3q4 {font-family:"trebuchet ms"…...

OSCP靶场--Nickel
OSCP靶场–Nickel 考点(1.POST方法请求信息 2.ftp,ssh密码复用 3.pdf文件密码爆破) 1.nmap扫描 ┌──(root㉿kali)-[~/Desktop] └─# nmap 192.168.237.99 -sV -sC -p- --min-rate 5000 Starting Nmap 7.92 ( https://nmap.org ) at 2024-02-22 04:06 EST Nm…...

新建工程——库函数版
新建工程——库函数版 s t e p I : 新建工程文件夹 \bf{stepI:新建工程文件夹} stepI:新建工程文件夹 s t e p I I : K e i l 5 新建工程 \bf{stepII:Keil5新建工程} stepII:Keil5新建工程 s t e p I I I : 最终得到工程文件 \bf{stepIII:最终得到工程文件} stepIII:最终得到工…...

java 数据结构栈和队列
目录 栈(Stack) 栈的使用 栈的模拟实现 栈的应用场景 队列(Queue) 队列的使用 队列模拟实现 循环队列 双端队列 用队列实现栈 用栈实现队列 栈(Stack) 什么是栈? 栈 :一种特殊的线性表,其 只允许在固定的一端进行插入和删除元素操…...

#LLM入门|Prompt#1.8_聊天机器人_Chatbot
聊天机器人设计 以会话形式进行交互,接受一系列消息作为输入,并返回模型生成的消息作为输出。原本设计用于简便多轮对话,但同样适用于单轮任务。 设计思路 个性化特性:通过定制模型的训练数据和参数,使机器人拥有特…...

LeetCode 2476.二叉搜索树最近节点查询:中序遍历 + 二分查找
【LetMeFly】2476.二叉搜索树最近节点查询:中序遍历 二分查找 力扣题目链接:https://leetcode.cn/problems/closest-nodes-queries-in-a-binary-search-tree/ 给你一个 二叉搜索树 的根节点 root ,和一个由正整数组成、长度为 n 的数组 qu…...

选座位 - 华为OD统一考试(C卷)
OD统一考试(C卷) 分值: 200分 题解: Java / Python / C 题目描述 疫情期间,需要大家保证一定的社交距离,公司组织开交流会议,座位有一排共N个座位,编号分别为[0…N-1],要…...

【微服务】mybatis typehandler使用详解
目录 一、前言 二、TypeHandler简介 2.1 什么是TypeHandler 2.1.1 TypeHandler特点 2.2 TypeHandler原理 2.3 mybatis自带的TypeHandler 三、环境准备 3.1 准备一张数据表 3.2 搭建一个springboot工程 3.2.1 基础依赖如下 3.2.2 核心配置文件 3.2.3 测试接口 四、T…...

计网 - 深入理解HTTPS:加密技术的背后
文章目录 Pre发展历史Http VS HttpsHTTPS 解决了 HTTP 的哪些问题HTTPS是如何解决上述三个风险的混合加密摘要算法 数字签名数字证书 Pre PKI - 数字签名与数字证书 PKI - 借助Nginx 实现Https 服务端单向认证、服务端客户端双向认证 发展历史 HTTP(超文本传输协…...

Jmeter之单接口的性能测试
前言: 服务端的整体性能测试是一个非常复杂的概念,包含生成虚拟用户,模拟并发,分析性能结果等各种技术,期间可能还要解决设计场景、缓存影响、第三方接口mock、IP限制等问题。如何用有限的测试机器,在测试环…...
成像光谱遥感技术中的AI革命:ChatGPT应用指南
“成像光谱遥感技术中的人工智能革命:ChatGPT应用指南”,这是一门旨在改变您使用人工智能处理遥感数据的方式。将最新的人工智能技术与实际的遥感应用相结合,提供不仅是理论上的,而且是适用和可靠的工具和方法。无论你是经验丰富的…...

掌握BeautifulSoup4:爬虫解析器的基础与实战【第91篇—BeautifulSoup4】
掌握BeautifulSoup4:爬虫解析器的基础与实战 网络上的信息浩如烟海,而爬虫技术正是帮助我们从中获取有用信息的重要工具。在爬虫过程中,解析HTML页面是一个关键步骤,而BeautifulSoup4正是一款功能强大的解析器,能够轻…...

从源码解析Kruise(K8S)原地升级原理
从源码解析Kruise原地升级原理 本文从源码的角度分析 Kruise 原地升级相关功能的实现。 本篇Kruise版本为v1.5.2。 Kruise项目地址: https://github.com/openkruise/kruise 更多云原生、K8S相关文章请点击【专栏】查看! 原地升级的概念 当我们使用deployment等Wor…...

2024年【广东省安全员C证第四批(专职安全生产管理人员)】复审考试及广东省安全员C证第四批(专职安全生产管理人员)模拟考试题
题库来源:安全生产模拟考试一点通公众号小程序 广东省安全员C证第四批(专职安全生产管理人员)复审考试是安全生产模拟考试一点通总题库中生成的一套广东省安全员C证第四批(专职安全生产管理人员)模拟考试题࿰…...

【Axure高保真原型】引导弹窗
今天和大家中分享引导弹窗的原型模板,载入页面后,会显示引导弹窗,适用于引导用户使用页面,点击完成后,会显示下一个引导弹窗,直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...

Docker 离线安装指南
参考文章 1、确认操作系统类型及内核版本 Docker依赖于Linux内核的一些特性,不同版本的Docker对内核版本有不同要求。例如,Docker 17.06及之后的版本通常需要Linux内核3.10及以上版本,Docker17.09及更高版本对应Linux内核4.9.x及更高版本。…...

51c自动驾驶~合集58
我自己的原文哦~ https://blog.51cto.com/whaosoft/13967107 #CCA-Attention 全局池化局部保留,CCA-Attention为LLM长文本建模带来突破性进展 琶洲实验室、华南理工大学联合推出关键上下文感知注意力机制(CCA-Attention),…...
数据链路层的主要功能是什么
数据链路层(OSI模型第2层)的核心功能是在相邻网络节点(如交换机、主机)间提供可靠的数据帧传输服务,主要职责包括: 🔑 核心功能详解: 帧封装与解封装 封装: 将网络层下发…...
TRS收益互换:跨境资本流动的金融创新工具与系统化解决方案
一、TRS收益互换的本质与业务逻辑 (一)概念解析 TRS(Total Return Swap)收益互换是一种金融衍生工具,指交易双方约定在未来一定期限内,基于特定资产或指数的表现进行现金流交换的协议。其核心特征包括&am…...

C# 求圆面积的程序(Program to find area of a circle)
给定半径r,求圆的面积。圆的面积应精确到小数点后5位。 例子: 输入:r 5 输出:78.53982 解释:由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982,因为我们只保留小数点后 5 位数字。 输…...
Java + Spring Boot + Mybatis 实现批量插入
在 Java 中使用 Spring Boot 和 MyBatis 实现批量插入可以通过以下步骤完成。这里提供两种常用方法:使用 MyBatis 的 <foreach> 标签和批处理模式(ExecutorType.BATCH)。 方法一:使用 XML 的 <foreach> 标签ÿ…...
比较数据迁移后MySQL数据库和OceanBase数据仓库中的表
设计一个MySQL数据库和OceanBase数据仓库的表数据比较的详细程序流程,两张表是相同的结构,都有整型主键id字段,需要每次从数据库分批取得2000条数据,用于比较,比较操作的同时可以再取2000条数据,等上一次比较完成之后,开始比较,直到比较完所有的数据。比较操作需要比较…...

pikachu靶场通关笔记19 SQL注入02-字符型注入(GET)
目录 一、SQL注入 二、字符型SQL注入 三、字符型注入与数字型注入 四、源码分析 五、渗透实战 1、渗透准备 2、SQL注入探测 (1)输入单引号 (2)万能注入语句 3、获取回显列orderby 4、获取数据库名database 5、获取表名…...
uniapp 集成腾讯云 IM 富媒体消息(地理位置/文件)
UniApp 集成腾讯云 IM 富媒体消息全攻略(地理位置/文件) 一、功能实现原理 腾讯云 IM 通过 消息扩展机制 支持富媒体类型,核心实现方式: 标准消息类型:直接使用 SDK 内置类型(文件、图片等)自…...