当前位置: 首页 > news >正文

Pytorch 复习总结 3

Pytorch 复习总结,仅供笔者使用,参考教材:

  • 《动手学深度学习》
  • Stanford University: Practical Machine Learning

本文主要内容为:Pytorch 多层感知机。

本文先介绍了多层感知机的用法,再就训练过程中经常出现的过拟合现象提出解决办法。


Pytorch 语法汇总:

  • Pytorch 张量的常见运算、线性代数、高等数学、概率论 部分 见 Pytorch 复习总结1;
  • Pytorch 线性神经网络 部分 见 Pytorch 复习总结2;
  • Pytorch 多层感知机 部分 见 Pytorch 复习总结3;
  • Pytorch 深度学习计算 部分 见 Pytorch 复习总结4;
  • Pytorch 卷积神经网络 部分 见 Pytorch 复习总结5;
  • Pytorch 现代卷积神经网络 部分 见 Pytorch 复习总结6;

目录

  • 一. 多层感知机
    • 1. 读取数据集
    • 2. 神经网络模型
    • 3. 激活函数
    • 4. 损失函数
    • 5. 优化器
    • 6. 训练
  • 二. 过拟合的缓解
    • 1. 权重衰减
    • 2. Dropout

一. 多层感知机

虽然线性模型易于实现和理解、计算成本低、泛化能力强,但是对于一些非线性问题,可能会违反线性模型的单调性。为此,多层感知器引入了隐藏层来克服线性模型的限制,并且加入激活函数以增强网络非线性建模能力。

1. 读取数据集

同 Pytorch 复习总结 2 中 Softmax 回归的数据读取,继续使用 Fashion-MNIST 图像分类数据集:

import torch
import torchvision
from torch.utils import data
from torchvision import transformsdef load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集并将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True),data.DataLoader(mnist_test, batch_size, shuffle=False))batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)

2. 神经网络模型

先将输入的图像展平,然后使用 2 个全连接层进行处理,中间的全连接层需要使用激活函数激活,最后一层全连接层作为输出:

from torch import nn
net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 10)
)

仍然使用 init_weights() 函数按正态分布初始化所有全连接层的权重:

def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)

3. 激活函数

上一节使用了 ReLU 函数进行激活,在实际应用中,还可以使用 sigmoid、tanh 等函数激活。ReLU、sigmoid、tanh 函数的梯度可视化如下:

import torch
from matplotlib import pyplot as pltx = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)
# y = torch.relu(x)
# y = torch.sigmoid(x)
y = torch.tanh(x)
y.backward(torch.ones_like(x), retain_graph=True)
plt.figure(figsize=(5, 2.5))
plt.plot(x.detach(), x.grad)
plt.show()

4. 损失函数

同 Softmax 回归:

loss = nn.CrossEntropyLoss(reduction='none')

5. 优化器

同 Softmax 回归:

trainer = torch.optim.SGD(net.parameters(), lr=0.1)

6. 训练

同 Softmax 回归,可以将训练过程封装成函数:

def accuracy(y_hat, y):"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def train_net(net, train_iter, test_iter, loss, num_epochs, trainer):for epoch in range(num_epochs):     # 迭代训练轮次net.train()                     # 将模型设置为训练模式train_loss_sum = 0.0            # 训练损失总和train_acc_sum = 0.0             # 训练准确度总和sample_num = 0                  # 样本数for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y)trainer.zero_grad()l.mean().backward()trainer.step()train_loss_sum += l.sum()train_acc_sum += accuracy(y_hat, y)sample_num += y.numel()train_loss = train_loss_sum / sample_numtrain_acc = train_acc_sum / sample_numnet.eval()                      # 将模型设置为评估模式test_acc_sum = 0.0test_sample_num = 0for X, y in test_iter:test_acc_sum += accuracy(net(X), y)test_sample_num += y.numel()test_acc = test_acc_sum / test_sample_numprint(f'epoch {epoch + 1}, 'f'train loss {train_loss:.4f}, train acc {train_acc:.4f}, 'f'test acc {test_acc:.4f}')num_epochs = 10
train_net(net, train_iter, test_iter, loss, num_epochs, trainer)

二. 过拟合的缓解

当模型过于复杂、训练数据太少、迭代轮数太多时,就会出现过拟合现象。解决过拟合的方法有很多:

  • 增加数据量:增加训练数据可以帮助模型更好地学习数据的真实规律,减少过拟合的发生;
  • 简化模型:降低模型的复杂度,可以通过减少模型的参数数量、使用正则化等方法来实现;
  • 交叉验证:使用交叉验证来评估模型的泛化能力,选择最优的模型;
  • 提前停止:即 Dropout,在训练过程中监控模型在验证集上的表现,当验证集误差不再下降甚至开始上升时,及时停止训练,防止模型过拟合;
  • 集成学习:使用集成学习方法(如随机森林、梯度提升树等)降低模型的方差,提高泛化能力。

下面介绍几种常用的正则化方法。

1. 权重衰减

权重衰减 (Weight Decay) 通过向损失函数中添加一个惩罚项来减小模型复杂度,以防止过拟合。惩罚项也叫 正则项,通常是权重的平方和(即 L2 范数)或权重的绝对值和(即 L1 范数)乘以一个正则化系数。

以线性回归的损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 为例,使用优化器训练时,在损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 上添加 L2 范数如下:
L ( w , b ) + λ 2 ∥ w ∥ 2 = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 + λ 2 ∥ w ∥ 2 L(\mathbf{w}, b)+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ =\frac{1}{n} \sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right)^2+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ L(w,b)+2λw2=n1i=1n21(wx(i)+by(i))2+2λw2

损失函数中没有添加偏置 b b b 的惩罚项,因为一般情况下,网络输出层的偏置项不需要正则化。代入 w \mathbf{w} w 的参数更新表达式为:
w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) \mathbf{w} \leftarrow(1-\eta \lambda) \mathbf{w}-\frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right) w(1ηλ)wBηiBx(i)(wx(i)+by(i))

要想对模型进行权重衰减,只需要在实例化优化器时通过 weight_decay 指定权重衰减参数。默认情况下,PyTorch 同时衰减权重和偏移:

trainer = torch.optim.SGD(net.parameters(), lr=lr)

如果想要只衰减权重,需要指定参数:

params_to_optimize = [{"params": net[0].weight, 'weight_decay': wd},{"params":net[0].bias}
]
trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)

2. Dropout

Dropout 通过在训练过程中随机地将网络 内部 的一部分神经元的输出设置为零,即以一定的概率 “丢弃” 这些神经元。这样可以防止神经元在训练过程中过于依赖其他神经元,从而降低了网络对特定神经元的依赖性,使得网络更具鲁棒性:
在这里插入图片描述

通常情况下,Dropout 只在训练过程中使用,不在推理阶段使用,因为推理时模型需要产生确定性的输出。

Dropout 需要在网络中添加 Dropout 层,一般位于激活函数后,并且给定 dropout 概率:

dropout1, dropout2 = 0.2, 0.5net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Dropout(dropout1),nn.Linear(256, 256),nn.ReLU(),nn.Dropout(dropout2),nn.Linear(256, 10)
)def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)

Dropout 概率的设置技巧是靠近输入层的地方设置较低的概率,远离输入层的地方设置较高的概率。

相关文章:

Pytorch 复习总结 3

Pytorch 复习总结,仅供笔者使用,参考教材: 《动手学深度学习》Stanford University: Practical Machine Learning 本文主要内容为:Pytorch 多层感知机。 本文先介绍了多层感知机的用法,再就训练过程中经常出现的过拟…...

2024年危险化学品经营单位主要负责人证考试题库及危险化学品经营单位主要负责人试题解析

题库来源:安全生产模拟考试一点通公众号小程序 2024年危险化学品经营单位主要负责人证考试题库及危险化学品经营单位主要负责人试题解析是安全生产模拟考试一点通结合(安监局)特种作业人员操作证考试大纲和(质检局)特…...

go使用trpc案例

1.go下载trpc go install trpc.group/trpc-go/trpc-cmdline/trpclatest 有报错的话尝试配置一些代理(选一个) go env -w GOPROXYhttps://goproxy.cn,direct go env -w GOPROXYhttps://goproxy.io,direct go env -w GOPROXYhttps://goproxy.baidu.com/…...

nodejs+vue+ElementUi废品废弃资源回收系统

系统主要是以后台管理员管理为主。管理员需要先登录系统然后才可以使用本系统,管理员可以对系统用户管理、用户信息管理、回收站点管理、站点分类管理、站点分类管理、留言板管理、系统管理进行添加、查询、修改、删除,以保障废弃资源回收系统系统的正常…...

【Java程序设计】【C00277】基于Springboot的招生管理系统(有论文)

基于Springboot的招生管理系统(有论文) 项目简介项目获取开发环境项目技术运行截图 项目简介 这是一个基于Springboot的招生管理系统 本系统分为系统功能模块、管理员功能模块以及学生功能模块。 系统功能模块:在系统首页可以查看首页、专业…...

汇编语言与接口技术实践——秒表

1. 设计要求 基于 51 开发板,利用键盘作为按键输入,将数码管作为显示输出,实现电子秒表。 功能要求: (1)计时精度达到百分之一秒; (2)能按键记录下5次时间并通过按键回看 (3)设置时间,实现倒计时,时间到,数码管闪烁 10 次,并激发蜂鸣器,可通过按键解除。 2. 设计思…...

【数据结构与算法】(19)高级数据结构与算法设计之 图 拓扑排序 最短路径 最小生成树 不相交集合(并查集合)代码示例

目录 6) 拓扑排序KahnDFS 7) 最短路径DijkstraBellman-FordFloyd-Warshall 8) 最小生成树PrimKruskal 9) 不相交集合(并查集合)基础路径压缩Union By Size 图-相关题目 6) 拓扑排序 #mermaid-svg-MQhLsXiMwnlUL3q4 {font-family:"trebuchet ms"…...

OSCP靶场--Nickel

OSCP靶场–Nickel 考点(1.POST方法请求信息 2.ftp,ssh密码复用 3.pdf文件密码爆破) 1.nmap扫描 ┌──(root㉿kali)-[~/Desktop] └─# nmap 192.168.237.99 -sV -sC -p- --min-rate 5000 Starting Nmap 7.92 ( https://nmap.org ) at 2024-02-22 04:06 EST Nm…...

新建工程——库函数版

新建工程——库函数版 s t e p I : 新建工程文件夹 \bf{stepI:新建工程文件夹} stepI:新建工程文件夹 s t e p I I : K e i l 5 新建工程 \bf{stepII:Keil5新建工程} stepII:Keil5新建工程 s t e p I I I : 最终得到工程文件 \bf{stepIII:最终得到工程文件} stepIII:最终得到工…...

java 数据结构栈和队列

目录 栈(Stack) 栈的使用 栈的模拟实现 栈的应用场景 队列(Queue) 队列的使用 队列模拟实现 循环队列 双端队列 用队列实现栈 用栈实现队列 栈(Stack) 什么是栈? 栈 :一种特殊的线性表,其 只允许在固定的一端进行插入和删除元素操…...

#LLM入门|Prompt#1.8_聊天机器人_Chatbot

聊天机器人设计 以会话形式进行交互,接受一系列消息作为输入,并返回模型生成的消息作为输出。原本设计用于简便多轮对话,但同样适用于单轮任务。 设计思路 个性化特性:通过定制模型的训练数据和参数,使机器人拥有特…...

LeetCode 2476.二叉搜索树最近节点查询:中序遍历 + 二分查找

【LetMeFly】2476.二叉搜索树最近节点查询:中序遍历 二分查找 力扣题目链接:https://leetcode.cn/problems/closest-nodes-queries-in-a-binary-search-tree/ 给你一个 二叉搜索树 的根节点 root ,和一个由正整数组成、长度为 n 的数组 qu…...

选座位 - 华为OD统一考试(C卷)

OD统一考试(C卷) 分值: 200分 题解: Java / Python / C 题目描述 疫情期间,需要大家保证一定的社交距离,公司组织开交流会议,座位有一排共N个座位,编号分别为[0…N-1],要…...

【微服务】mybatis typehandler使用详解

目录 一、前言 二、TypeHandler简介 2.1 什么是TypeHandler 2.1.1 TypeHandler特点 2.2 TypeHandler原理 2.3 mybatis自带的TypeHandler 三、环境准备 3.1 准备一张数据表 3.2 搭建一个springboot工程 3.2.1 基础依赖如下 3.2.2 核心配置文件 3.2.3 测试接口 四、T…...

计网 - 深入理解HTTPS:加密技术的背后

文章目录 Pre发展历史Http VS HttpsHTTPS 解决了 HTTP 的哪些问题HTTPS是如何解决上述三个风险的混合加密摘要算法 数字签名数字证书 Pre PKI - 数字签名与数字证书 PKI - 借助Nginx 实现Https 服务端单向认证、服务端客户端双向认证 发展历史 HTTP(超文本传输协…...

Jmeter之单接口的性能测试

前言: 服务端的整体性能测试是一个非常复杂的概念,包含生成虚拟用户,模拟并发,分析性能结果等各种技术,期间可能还要解决设计场景、缓存影响、第三方接口mock、IP限制等问题。如何用有限的测试机器,在测试环…...

成像光谱遥感技术中的AI革命:ChatGPT应用指南

“成像光谱遥感技术中的人工智能革命:ChatGPT应用指南”,这是一门旨在改变您使用人工智能处理遥感数据的方式。将最新的人工智能技术与实际的遥感应用相结合,提供不仅是理论上的,而且是适用和可靠的工具和方法。无论你是经验丰富的…...

掌握BeautifulSoup4:爬虫解析器的基础与实战【第91篇—BeautifulSoup4】

掌握BeautifulSoup4:爬虫解析器的基础与实战 网络上的信息浩如烟海,而爬虫技术正是帮助我们从中获取有用信息的重要工具。在爬虫过程中,解析HTML页面是一个关键步骤,而BeautifulSoup4正是一款功能强大的解析器,能够轻…...

从源码解析Kruise(K8S)原地升级原理

从源码解析Kruise原地升级原理 本文从源码的角度分析 Kruise 原地升级相关功能的实现。 本篇Kruise版本为v1.5.2。 Kruise项目地址: https://github.com/openkruise/kruise 更多云原生、K8S相关文章请点击【专栏】查看! 原地升级的概念 当我们使用deployment等Wor…...

2024年【广东省安全员C证第四批(专职安全生产管理人员)】复审考试及广东省安全员C证第四批(专职安全生产管理人员)模拟考试题

题库来源:安全生产模拟考试一点通公众号小程序 广东省安全员C证第四批(专职安全生产管理人员)复审考试是安全生产模拟考试一点通总题库中生成的一套广东省安全员C证第四批(专职安全生产管理人员)模拟考试题&#xff0…...

conda相比python好处

Conda 作为 Python 的环境和包管理工具,相比原生 Python 生态(如 pip 虚拟环境)有许多独特优势,尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处: 一、一站式环境管理&#xff1a…...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,百货中心供应链管理系统被用户普遍使用,为方…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

Qt Http Server模块功能及架构

Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...

Nginx server_name 配置说明

Nginx 是一个高性能的反向代理和负载均衡服务器,其核心配置之一是 server 块中的 server_name 指令。server_name 决定了 Nginx 如何根据客户端请求的 Host 头匹配对应的虚拟主机(Virtual Host)。 1. 简介 Nginx 使用 server_name 指令来确定…...

C++.OpenGL (10/64)基础光照(Basic Lighting)

基础光照(Basic Lighting) 冯氏光照模型(Phong Lighting Model) #mermaid-svg-GLdskXwWINxNGHso {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GLdskXwWINxNGHso .error-icon{fill:#552222;}#mermaid-svg-GLd…...

PL0语法,分析器实现!

简介 PL/0 是一种简单的编程语言,通常用于教学编译原理。它的语法结构清晰,功能包括常量定义、变量声明、过程(子程序)定义以及基本的控制结构(如条件语句和循环语句)。 PL/0 语法规范 PL/0 是一种教学用的小型编程语言,由 Niklaus Wirth 设计,用于展示编译原理的核…...

七、数据库的完整性

七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...

uniapp 开发ios, xcode 提交app store connect 和 testflight内测

uniapp 中配置 配置manifest 文档:manifest.json 应用配置 | uni-app官网 hbuilderx中本地打包 下载IOS最新SDK 开发环境 | uni小程序SDK hbulderx 版本号:4.66 对应的sdk版本 4.66 两者必须一致 本地打包的资源导入到SDK 导入资源 | uni小程序SDK …...

uniapp 字符包含的相关方法

在uniapp中,如果你想检查一个字符串是否包含另一个子字符串,你可以使用JavaScript中的includes()方法或者indexOf()方法。这两种方法都可以达到目的,但它们在处理方式和返回值上有所不同。 使用includes()方法 includes()方法用于判断一个字…...