当前位置: 首页 > news >正文

Pytorch 复习总结 3

Pytorch 复习总结,仅供笔者使用,参考教材:

  • 《动手学深度学习》
  • Stanford University: Practical Machine Learning

本文主要内容为:Pytorch 多层感知机。

本文先介绍了多层感知机的用法,再就训练过程中经常出现的过拟合现象提出解决办法。


Pytorch 语法汇总:

  • Pytorch 张量的常见运算、线性代数、高等数学、概率论 部分 见 Pytorch 复习总结1;
  • Pytorch 线性神经网络 部分 见 Pytorch 复习总结2;
  • Pytorch 多层感知机 部分 见 Pytorch 复习总结3;
  • Pytorch 深度学习计算 部分 见 Pytorch 复习总结4;
  • Pytorch 卷积神经网络 部分 见 Pytorch 复习总结5;
  • Pytorch 现代卷积神经网络 部分 见 Pytorch 复习总结6;

目录

  • 一. 多层感知机
    • 1. 读取数据集
    • 2. 神经网络模型
    • 3. 激活函数
    • 4. 损失函数
    • 5. 优化器
    • 6. 训练
  • 二. 过拟合的缓解
    • 1. 权重衰减
    • 2. Dropout

一. 多层感知机

虽然线性模型易于实现和理解、计算成本低、泛化能力强,但是对于一些非线性问题,可能会违反线性模型的单调性。为此,多层感知器引入了隐藏层来克服线性模型的限制,并且加入激活函数以增强网络非线性建模能力。

1. 读取数据集

同 Pytorch 复习总结 2 中 Softmax 回归的数据读取,继续使用 Fashion-MNIST 图像分类数据集:

import torch
import torchvision
from torch.utils import data
from torchvision import transformsdef load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集并将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True),data.DataLoader(mnist_test, batch_size, shuffle=False))batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)

2. 神经网络模型

先将输入的图像展平,然后使用 2 个全连接层进行处理,中间的全连接层需要使用激活函数激活,最后一层全连接层作为输出:

from torch import nn
net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 10)
)

仍然使用 init_weights() 函数按正态分布初始化所有全连接层的权重:

def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)

3. 激活函数

上一节使用了 ReLU 函数进行激活,在实际应用中,还可以使用 sigmoid、tanh 等函数激活。ReLU、sigmoid、tanh 函数的梯度可视化如下:

import torch
from matplotlib import pyplot as pltx = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)
# y = torch.relu(x)
# y = torch.sigmoid(x)
y = torch.tanh(x)
y.backward(torch.ones_like(x), retain_graph=True)
plt.figure(figsize=(5, 2.5))
plt.plot(x.detach(), x.grad)
plt.show()

4. 损失函数

同 Softmax 回归:

loss = nn.CrossEntropyLoss(reduction='none')

5. 优化器

同 Softmax 回归:

trainer = torch.optim.SGD(net.parameters(), lr=0.1)

6. 训练

同 Softmax 回归,可以将训练过程封装成函数:

def accuracy(y_hat, y):"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def train_net(net, train_iter, test_iter, loss, num_epochs, trainer):for epoch in range(num_epochs):     # 迭代训练轮次net.train()                     # 将模型设置为训练模式train_loss_sum = 0.0            # 训练损失总和train_acc_sum = 0.0             # 训练准确度总和sample_num = 0                  # 样本数for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y)trainer.zero_grad()l.mean().backward()trainer.step()train_loss_sum += l.sum()train_acc_sum += accuracy(y_hat, y)sample_num += y.numel()train_loss = train_loss_sum / sample_numtrain_acc = train_acc_sum / sample_numnet.eval()                      # 将模型设置为评估模式test_acc_sum = 0.0test_sample_num = 0for X, y in test_iter:test_acc_sum += accuracy(net(X), y)test_sample_num += y.numel()test_acc = test_acc_sum / test_sample_numprint(f'epoch {epoch + 1}, 'f'train loss {train_loss:.4f}, train acc {train_acc:.4f}, 'f'test acc {test_acc:.4f}')num_epochs = 10
train_net(net, train_iter, test_iter, loss, num_epochs, trainer)

二. 过拟合的缓解

当模型过于复杂、训练数据太少、迭代轮数太多时,就会出现过拟合现象。解决过拟合的方法有很多:

  • 增加数据量:增加训练数据可以帮助模型更好地学习数据的真实规律,减少过拟合的发生;
  • 简化模型:降低模型的复杂度,可以通过减少模型的参数数量、使用正则化等方法来实现;
  • 交叉验证:使用交叉验证来评估模型的泛化能力,选择最优的模型;
  • 提前停止:即 Dropout,在训练过程中监控模型在验证集上的表现,当验证集误差不再下降甚至开始上升时,及时停止训练,防止模型过拟合;
  • 集成学习:使用集成学习方法(如随机森林、梯度提升树等)降低模型的方差,提高泛化能力。

下面介绍几种常用的正则化方法。

1. 权重衰减

权重衰减 (Weight Decay) 通过向损失函数中添加一个惩罚项来减小模型复杂度,以防止过拟合。惩罚项也叫 正则项,通常是权重的平方和(即 L2 范数)或权重的绝对值和(即 L1 范数)乘以一个正则化系数。

以线性回归的损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 为例,使用优化器训练时,在损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 上添加 L2 范数如下:
L ( w , b ) + λ 2 ∥ w ∥ 2 = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 + λ 2 ∥ w ∥ 2 L(\mathbf{w}, b)+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ =\frac{1}{n} \sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right)^2+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ L(w,b)+2λw2=n1i=1n21(wx(i)+by(i))2+2λw2

损失函数中没有添加偏置 b b b 的惩罚项,因为一般情况下,网络输出层的偏置项不需要正则化。代入 w \mathbf{w} w 的参数更新表达式为:
w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) \mathbf{w} \leftarrow(1-\eta \lambda) \mathbf{w}-\frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right) w(1ηλ)wBηiBx(i)(wx(i)+by(i))

要想对模型进行权重衰减,只需要在实例化优化器时通过 weight_decay 指定权重衰减参数。默认情况下,PyTorch 同时衰减权重和偏移:

trainer = torch.optim.SGD(net.parameters(), lr=lr)

如果想要只衰减权重,需要指定参数:

params_to_optimize = [{"params": net[0].weight, 'weight_decay': wd},{"params":net[0].bias}
]
trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)

2. Dropout

Dropout 通过在训练过程中随机地将网络 内部 的一部分神经元的输出设置为零,即以一定的概率 “丢弃” 这些神经元。这样可以防止神经元在训练过程中过于依赖其他神经元,从而降低了网络对特定神经元的依赖性,使得网络更具鲁棒性:
在这里插入图片描述

通常情况下,Dropout 只在训练过程中使用,不在推理阶段使用,因为推理时模型需要产生确定性的输出。

Dropout 需要在网络中添加 Dropout 层,一般位于激活函数后,并且给定 dropout 概率:

dropout1, dropout2 = 0.2, 0.5net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Dropout(dropout1),nn.Linear(256, 256),nn.ReLU(),nn.Dropout(dropout2),nn.Linear(256, 10)
)def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)

Dropout 概率的设置技巧是靠近输入层的地方设置较低的概率,远离输入层的地方设置较高的概率。

相关文章:

Pytorch 复习总结 3

Pytorch 复习总结,仅供笔者使用,参考教材: 《动手学深度学习》Stanford University: Practical Machine Learning 本文主要内容为:Pytorch 多层感知机。 本文先介绍了多层感知机的用法,再就训练过程中经常出现的过拟…...

2024年危险化学品经营单位主要负责人证考试题库及危险化学品经营单位主要负责人试题解析

题库来源:安全生产模拟考试一点通公众号小程序 2024年危险化学品经营单位主要负责人证考试题库及危险化学品经营单位主要负责人试题解析是安全生产模拟考试一点通结合(安监局)特种作业人员操作证考试大纲和(质检局)特…...

go使用trpc案例

1.go下载trpc go install trpc.group/trpc-go/trpc-cmdline/trpclatest 有报错的话尝试配置一些代理(选一个) go env -w GOPROXYhttps://goproxy.cn,direct go env -w GOPROXYhttps://goproxy.io,direct go env -w GOPROXYhttps://goproxy.baidu.com/…...

nodejs+vue+ElementUi废品废弃资源回收系统

系统主要是以后台管理员管理为主。管理员需要先登录系统然后才可以使用本系统,管理员可以对系统用户管理、用户信息管理、回收站点管理、站点分类管理、站点分类管理、留言板管理、系统管理进行添加、查询、修改、删除,以保障废弃资源回收系统系统的正常…...

【Java程序设计】【C00277】基于Springboot的招生管理系统(有论文)

基于Springboot的招生管理系统(有论文) 项目简介项目获取开发环境项目技术运行截图 项目简介 这是一个基于Springboot的招生管理系统 本系统分为系统功能模块、管理员功能模块以及学生功能模块。 系统功能模块:在系统首页可以查看首页、专业…...

汇编语言与接口技术实践——秒表

1. 设计要求 基于 51 开发板,利用键盘作为按键输入,将数码管作为显示输出,实现电子秒表。 功能要求: (1)计时精度达到百分之一秒; (2)能按键记录下5次时间并通过按键回看 (3)设置时间,实现倒计时,时间到,数码管闪烁 10 次,并激发蜂鸣器,可通过按键解除。 2. 设计思…...

【数据结构与算法】(19)高级数据结构与算法设计之 图 拓扑排序 最短路径 最小生成树 不相交集合(并查集合)代码示例

目录 6) 拓扑排序KahnDFS 7) 最短路径DijkstraBellman-FordFloyd-Warshall 8) 最小生成树PrimKruskal 9) 不相交集合(并查集合)基础路径压缩Union By Size 图-相关题目 6) 拓扑排序 #mermaid-svg-MQhLsXiMwnlUL3q4 {font-family:"trebuchet ms"…...

OSCP靶场--Nickel

OSCP靶场–Nickel 考点(1.POST方法请求信息 2.ftp,ssh密码复用 3.pdf文件密码爆破) 1.nmap扫描 ┌──(root㉿kali)-[~/Desktop] └─# nmap 192.168.237.99 -sV -sC -p- --min-rate 5000 Starting Nmap 7.92 ( https://nmap.org ) at 2024-02-22 04:06 EST Nm…...

新建工程——库函数版

新建工程——库函数版 s t e p I : 新建工程文件夹 \bf{stepI:新建工程文件夹} stepI:新建工程文件夹 s t e p I I : K e i l 5 新建工程 \bf{stepII:Keil5新建工程} stepII:Keil5新建工程 s t e p I I I : 最终得到工程文件 \bf{stepIII:最终得到工程文件} stepIII:最终得到工…...

java 数据结构栈和队列

目录 栈(Stack) 栈的使用 栈的模拟实现 栈的应用场景 队列(Queue) 队列的使用 队列模拟实现 循环队列 双端队列 用队列实现栈 用栈实现队列 栈(Stack) 什么是栈? 栈 :一种特殊的线性表,其 只允许在固定的一端进行插入和删除元素操…...

#LLM入门|Prompt#1.8_聊天机器人_Chatbot

聊天机器人设计 以会话形式进行交互,接受一系列消息作为输入,并返回模型生成的消息作为输出。原本设计用于简便多轮对话,但同样适用于单轮任务。 设计思路 个性化特性:通过定制模型的训练数据和参数,使机器人拥有特…...

LeetCode 2476.二叉搜索树最近节点查询:中序遍历 + 二分查找

【LetMeFly】2476.二叉搜索树最近节点查询:中序遍历 二分查找 力扣题目链接:https://leetcode.cn/problems/closest-nodes-queries-in-a-binary-search-tree/ 给你一个 二叉搜索树 的根节点 root ,和一个由正整数组成、长度为 n 的数组 qu…...

选座位 - 华为OD统一考试(C卷)

OD统一考试(C卷) 分值: 200分 题解: Java / Python / C 题目描述 疫情期间,需要大家保证一定的社交距离,公司组织开交流会议,座位有一排共N个座位,编号分别为[0…N-1],要…...

【微服务】mybatis typehandler使用详解

目录 一、前言 二、TypeHandler简介 2.1 什么是TypeHandler 2.1.1 TypeHandler特点 2.2 TypeHandler原理 2.3 mybatis自带的TypeHandler 三、环境准备 3.1 准备一张数据表 3.2 搭建一个springboot工程 3.2.1 基础依赖如下 3.2.2 核心配置文件 3.2.3 测试接口 四、T…...

计网 - 深入理解HTTPS:加密技术的背后

文章目录 Pre发展历史Http VS HttpsHTTPS 解决了 HTTP 的哪些问题HTTPS是如何解决上述三个风险的混合加密摘要算法 数字签名数字证书 Pre PKI - 数字签名与数字证书 PKI - 借助Nginx 实现Https 服务端单向认证、服务端客户端双向认证 发展历史 HTTP(超文本传输协…...

Jmeter之单接口的性能测试

前言: 服务端的整体性能测试是一个非常复杂的概念,包含生成虚拟用户,模拟并发,分析性能结果等各种技术,期间可能还要解决设计场景、缓存影响、第三方接口mock、IP限制等问题。如何用有限的测试机器,在测试环…...

成像光谱遥感技术中的AI革命:ChatGPT应用指南

“成像光谱遥感技术中的人工智能革命:ChatGPT应用指南”,这是一门旨在改变您使用人工智能处理遥感数据的方式。将最新的人工智能技术与实际的遥感应用相结合,提供不仅是理论上的,而且是适用和可靠的工具和方法。无论你是经验丰富的…...

掌握BeautifulSoup4:爬虫解析器的基础与实战【第91篇—BeautifulSoup4】

掌握BeautifulSoup4:爬虫解析器的基础与实战 网络上的信息浩如烟海,而爬虫技术正是帮助我们从中获取有用信息的重要工具。在爬虫过程中,解析HTML页面是一个关键步骤,而BeautifulSoup4正是一款功能强大的解析器,能够轻…...

从源码解析Kruise(K8S)原地升级原理

从源码解析Kruise原地升级原理 本文从源码的角度分析 Kruise 原地升级相关功能的实现。 本篇Kruise版本为v1.5.2。 Kruise项目地址: https://github.com/openkruise/kruise 更多云原生、K8S相关文章请点击【专栏】查看! 原地升级的概念 当我们使用deployment等Wor…...

2024年【广东省安全员C证第四批(专职安全生产管理人员)】复审考试及广东省安全员C证第四批(专职安全生产管理人员)模拟考试题

题库来源:安全生产模拟考试一点通公众号小程序 广东省安全员C证第四批(专职安全生产管理人员)复审考试是安全生产模拟考试一点通总题库中生成的一套广东省安全员C证第四批(专职安全生产管理人员)模拟考试题&#xff0…...

Unity开发HoloLens应用:从打包到安装的完整避坑指南(2024最新版)

Unity开发HoloLens应用:从打包到安装的完整避坑指南(2024最新版) 如果你正在尝试将Unity项目部署到HoloLens设备上,可能会遇到各种意想不到的问题。作为一位经历过无数次打包、部署、调试循环的开发者,我想分享一些实战…...

提升90%效率:OpenCore EFI自动化配置工具OpCore-Simplify实战指南

提升90%效率:OpenCore EFI自动化配置工具OpCore-Simplify实战指南 【免费下载链接】OpCore-Simplify A tool designed to simplify the creation of OpenCore EFI 项目地址: https://gitcode.com/GitHub_Trending/op/OpCore-Simplify 副标题:面向…...

自编码器在异常检测中的实战应用:以金融交易数据为例

自编码器在金融异常检测中的实战指南:从数据清洗到模型部署 金融交易数据中的异常行为检测一直是风险控制的核心环节。传统基于规则的系统难以应对日益复杂的欺诈模式,而自编码器这类无监督学习模型正在改变游戏规则。本文将带您从零构建一个完整的异常检…...

Scarab:基于Avalonia的跨平台空洞骑士模组管理器架构解析

Scarab:基于Avalonia的跨平台空洞骑士模组管理器架构解析 【免费下载链接】Scarab An installer for Hollow Knight mods written in Avalonia. 项目地址: https://gitcode.com/gh_mirrors/sc/Scarab Scarab是一款专为《空洞骑士》游戏设计的跨平台模组管理器…...

5分钟部署清华TurboDiffusion,视频生成加速100倍,小白也能玩转AI视频

5分钟部署清华TurboDiffusion,视频生成加速100倍,小白也能玩转AI视频 1. TurboDiffusion技术背景与核心价值 1.1 技术发展历程 TurboDiffusion是由清华大学等机构联合推出的视频生成加速框架。该框架解决了传统扩散模型在视频生成过程中存在的计算效率…...

Lingyuxiu MXJ LoRA效果展示:masterpiece+best quality+8k三重加持高清输出

Lingyuxiu MXJ LoRA效果展示:masterpiecebest quality8k三重加持高清输出 1. 引言:当唯美人像遇上AI创作 想象一下,你是一位摄影师或设计师,需要创作一组具有特定艺术风格的人像作品。传统的流程需要寻找模特、布置灯光、后期精…...

RexUniNLU案例集:制造业设备报修场景中,‘异响’‘漏油’‘停机’故障标签识别效果

RexUniNLU案例集:制造业设备报修场景中,‘异响’‘漏油’‘停机’故障标签识别效果 1. 引言:当设备“说话”时,我们如何听懂? 想象一下这个场景:在一条繁忙的生产线上,一台关键设备突然发出“…...

5个高效能的LabelImg图像标注效率提升实践

5个高效能的LabelImg图像标注效率提升实践 【免费下载链接】labelImg LabelImg is now part of the Label Studio community. The popular image annotation tool created by Tzutalin is no longer actively being developed, but you can check out Label Studio, the open s…...

SUPER COLORIZER一键部署指南:基于Ubuntu 20.04的完整环境配置教程

SUPER COLORIZER一键部署指南:基于Ubuntu 20.04的完整环境配置教程 你是不是也遇到过一些珍贵的老照片,因为年代久远而褪色,想恢复它原本的色彩却无从下手?或者,你有一些黑白的设计稿,想快速预览上色后的效…...

Ollama部署LFM2.5-1.2B-Thinking:轻量模型在边缘设备上的真实性能报告

Ollama部署LFM2.5-1.2B-Thinking:轻量模型在边缘设备上的真实性能报告 1. 模型介绍:专为边缘设备设计的智能助手 LFM2.5-1.2B-Thinking是一个专门为设备端部署优化的文本生成模型,它在LFM2架构基础上进行了深度改进。这个模型最大的特点就是…...