当前位置: 首页 > article >正文

训练与优化

训练与优化

损失函数与反向传播

损失函数能够衡量神经网络输出与目标值之间的误差,同时为反向传播提供依据,计算梯度来优化网络中的参数。

torch.nn.L1Loss 计算所有预测值与真实值之间的绝对差。参数为 reduction

  • 'none':不对损失进行任何求和或平均,返回每个元素的损失。
  • 'mean':对损失进行平均,默认选项。
  • 'sum':对所有样本的损失进行求和。
import torchinput = torch.tensor([1, 2, 3], dtype=torch.float32)
target = torch.tensor([1, 3, 5], dtype=torch.float32)loss = torch.nn.L1Loss(reduction="none")
res = loss(input, target)
print(res)
# tensor([0., 1., 2.])loss = torch.nn.L1Loss(reduction="mean")
res = loss(input, target)
print(res)
# tensor(1.)loss = torch.nn.L1Loss(reduction="sum")
res = loss(input, target)
print(res)
# tensor(3.)

torch.nn.MSELoss 计算每个样本的预测值与真实值之间的差距的平方,参数为 reduction

import torchinput = torch.tensor([1, 2, 3], dtype=torch.float32)
target = torch.tensor([1, 3, 5], dtype=torch.float32)loss = torch.nn.MSELoss(reduction="none")
res = loss(input, target)
print(res)
# tensor([0., 1., 4.])loss = torch.nn.MSELoss(reduction="mean")
res = loss(input, target)
print(res)
# tensor(1.6667)loss = torch.nn.MSELoss(reduction="sum")
res = loss(input, target)
print(res)
# tensor(5.)

torch.nn.CrossEntropyLoss 计算实际类别分布预测类别分布之间的差异。输入 input 为预测的类别得分(不是概率),维度为 (N,C) ,其中 N 是样本数量,C 是类别数量,每个样本是一个未经过softmax 的类别得分。真实标签索引 target 维度为 (N) ,每个标签是一个整数,表示该样本的真实类别索引。

CrossEntropyLoss自动计算 input 的 softmax ,然后根据交叉熵公式计算每个样本的损失。

import torch
from torch import nn# 2个样本,3个类别的得分
input = torch.tensor([[1, 2, 3], [1, 2, 3]], dtype=torch.float32)
# 真实标签:第1个样本属于类别2,第2个样本属于类别1
target = torch.tensor([2, 1])loss = nn.CrossEntropyLoss()res = loss(input, target)
print(res)
# tensor(0.9076)

如果数据集中的类别不平衡,可以通过 weight 参数对每个类别的损失进行加权。这样可以让模型在训练时更加关注某些类别。

import torch
from torch import nn# 2个样本,3个类别的得分
input = torch.tensor([[1, 2, 3], [1, 2, 3]], dtype=torch.float32)
# 真实标签:第1个样本属于类别2,第2个样本属于类别1
target = torch.tensor([2, 1])# 类别0权重为1,类别1权重为2,类别2权重为0.5
weight = torch.tensor([1.0, 2.0, 0.5])
loss = nn.CrossEntropyLoss(weight)res = loss(input, target)
print(res)
# tensor(1.2076)

当计算出损失函数后,便可计算出每一个节点参数的梯度,从而进行反向传播,只需要加上一行:

result_loss.backward()

训练与推理

在 PyTorch 中,神经网络的 train()eval() 模式控制着 Batch NormalizationDropout 这两类层的行为,确保模型在训练和推理(测试)时的表现一致。

model.train() 负责启动 BN 和 Dropout 层的训练模式。BatchNorm 会计算当前批次的均值和方差,用于归一化数据,这些均值和方差会随着训练逐步更新。Dropout 会随机丢弃一部分神经元,以减少过拟合。

model.eval() 负责关闭训练模式,进入推理模式,确保计算的均值、方差、Dropout 影响不会波动,保证结果稳定。计算归一化时,会使用训练期间学到的全局均值和方差,而不是当前批次的统计量。也不再随机丢弃神经元,而是使用完整的网络进行预测。

在训练的时候,还需要关闭梯度计算,减少内存占用,加速推理。因为推理时不需要计算梯度,不需要 backward() 进行反向传播。

with torch.no_grad():output = model(input)

train() 模式下,PyTorch 默认存储计算图,以支持 backward() 计算梯度torch.no_grad() 关闭计算图,避免存储不必要的梯度信息,减少显存占用。

训练模式

model.train()  # 训练模式(启用 BatchNorm 统计 和 Dropout)
for data in dataloader:optimizer.zero_grad()output = model(data)loss = loss_fn(output, target)loss.backward()optimizer.step()

推理模式

model.eval()  # 进入推理模式
with torch.no_grad():  # 关闭梯度计算output = model(input)

优化器

优化器利用通过反向传播计算得到的梯度来更新模型参数,从而减小损失函数值,提升模型的性能。

在每次训练过程中,首先使用 optimizer.zero_grad() 清零上一步的梯度,然后通过 loss.backward() 执行反向传播,计算当前模型参数的梯度,最后使用 optimizer.step() 根据梯度更新模型参数。

**SGD(随机梯度下降)**是基本的梯度下降法,每次更新一个小批量的数据(mini-batch)参数,需要调整学习率(lr)和可能的动量(momentum)等超参数。

Adam、Adagrad、Adadelta、RMSProp 是不同的优化算法,每种算法有不同的超参数调整方法,Adam自适应调整学习率,Adagrad适用于稀疏数据,Adadelta主要针对自适应学习率的调整。

学习速率不能太大(太大模型训练不稳定)也不能太小(太小模型训练慢),一般建议先采用较大学习速率,后采用较小学习速率。

优化器构造方法:

# SGD(Stochastic Gradient Descent) 随机梯度下降
# 模型参数、学习速率、动量
**optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)**  

优化器调用方法:

for input, target in dataset:optimizer.zero_grad()            # 清空梯度output = model(input)res= loss(output, target)        # 计算损失函数res.backward()                   # 反向传播计算梯度optimizer.step()                 # 根据梯度优化参数

以 CIFAR-10 数据集为例:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# 加载数据集
dataset = torchvision.datasets.CIFAR10(root="Dataset", train=False,transform=torchvision.transforms.ToTensor(), download=False)
# 批量加载数据
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
writer = SummaryWriter("logs")class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.model1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(in_features=1024, out_features=64),nn.Linear(in_features=64, out_features=10))def forward(self, x):x = self.model1(x)return xmodel = Model()
# 定义损失函数
loss = torch.nn.CrossEntropyLoss()
# 定义优化器
**optimizer = torch.optim.SGD(model.parameters(), lr=0.01)**# 训练 20 个 epoch
for epoch in range(20):totalloss = 0.0for data in dataloader:optimizer.zero_grad()    # 清空梯度imgs, targets = dataoutputs = model(imgs)lossres = loss(outputs, targets)    # 计算损失totalloss = totalloss + lossres     # 累加损失lossres.backward()                  # 反向传播计算梯度optimizer.step()                    # 更新模型参数print("Epoch{} : {}".format(epoch, totalloss))# 写入 TensorBoardwriter.add_scalar("train_loss", totalloss, epoch)writer.close()

如果模型在训练时过早出现 nan 或损失不收敛,可以尝试调整学习率,使用更小的学习率或更高级的优化器(如 Adam)。

预训练模型

PyTorch 主要提供搭建神经网络的核心工具,TorchVision 提供了一系列预训练模型、标准数据集(如 ImageNet、CIFAR-10 等)和图像变换工具(transforms)。预训练模型(如 VGG16)在 ImageNet 数据集上已经训练好,可以直接使用或者在此基础上微调。

VGG16 是一种经典的卷积神经网络,主要用于图像分类任务。VGG16 由多层卷积层、池化层和全连接层组成,features 部分用于提取图像特征,classifier 部分用于分类,最终输出1000个类别。

torchvision.models.vgg16(weights, progress)

progess 代表是否显示下载进度条,默认 True,表示在下载权重时显示进度条。

weights 是预训练权重,默认为 None 不加载预训练模型。权重 VGG16_Weights.IMAGENET1K_V1 适用于分类任务,基于 ImageNet 训练,包含完整的分类器(classifier 层),VGG16_Weights.DEFAULT 等同于 VGG16_Weights.IMAGENET1K_V1

import torchvision# 无预训练权重(随机初始化参数)
vgg16_false = torchvision.models.vgg16(weights=None)
# 使用 ImageNet 预训练参数
vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
# 默认使用 ImageNet 预训练权重
vgg16_default = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)

但是 VGG16 对于图像输入有严格要求,输入维度必须是 224 × 224 224 \times 224 224×224

# 图像预处理(按 VGG16 需要的格式)
transform = transforms.Compose([transforms.Resize(256),                 # 先缩放到 256transforms.CenterCrop(224),             # 再中心裁剪到 224transforms.ToTensor(),# 归一化transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

权重 VGG16_Weights.IMAGENET1K_FEATURES 用于特征提取,不包含 classifier 部分权重,只能提取特征,不能进行分类(只是不包含预训练的分类器权重,并没有移除分类器层)。适用于迁移学习,可以用 features 层进行特征提取。

import torchvisionvgg16_feature = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_FEATURES)

VGG Model Structure

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))......(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

分类任务

import torch
from PIL import Image
from torchvision import models, transforms# 加载 VGG16 预训练模型
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)# 定义图像预处理步骤
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])img_path = r"Dataset/airplane.png"
img = Image.open(img_path)
input = transform(img)
# 添加 batch 维度
input = torch.reshape(input, (1, 3, 224, 224))# 进入推理模式
model.eval()
# 前向传播
with torch.no_grad():output = model(input)# 获取预测类别索引
predicted_class = torch.argmax(output)
# 获取 ImageNet 1000 类的类别名称
classes = models.VGG16_Weights.IMAGENET1K_V1.meta["categories"]
print(classes[predicted_class])

迁移学习微调模型

如果要迁移到 CIFAR-10 的分类任务,需要修改最后一层

from torch import nn
from torchvision import models# 加载 VGG16 预训练模型
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)# 修改 classifier 部分(改为 10 类)
**model.classifier[6]** = nn.Linear(in_features=4096, out_features=10)

或者添加新层

model.classifier.add_module("7", nn.Linear(in_features=1000, out_features=10))

如果只训练最后一层,可以冻结前面的参数

for param in model.features.parameters():param.requires_grad = False  # 冻结 features 部分(不更新)

这样可以 保留原有的卷积特征,仅微调分类层,提高训练效率。

相关文章:

训练与优化

训练与优化 损失函数与反向传播 损失函数能够衡量神经网络输出与目标值之间的误差,同时为反向传播提供依据,计算梯度来优化网络中的参数。 torch.nn.L1Loss 计算所有预测值与真实值之间的绝对差。参数为 reduction : none:不对…...

VsCode美化 Json

1.扩展中输入:pretty json 2. (CtrlA)选择Json文本 示例:{ "name" : "runoob" , "alexa" :10000, "site" : null , "sites" :[ "Google" , "Runoob" , "T…...

基于Spring Boot的社区居民健康管理平台的设计与实现

目录 1 绪论 1.1 研究现状 1.2 研究意义 1.3 组织结构 2 技术介绍 2.1 平台开发工具和环境 2.2 Vue介绍 2.3 Spring Boot 2.4 MyBatis 2.5 环境搭建 3 系统需求分析 3.1 可行性分析 3.2 功能需求分析 3.3 系统用例图 3.4 系统功能图 4 系统设计 4.1 系统总体描…...

使用Java爬虫获取京东商品SKU信息的完整指南

在电商领域,商品SKU(Stock Keeping Unit)信息是商家和消费者都非常关注的内容。SKU信息不仅包括商品的基本属性(如价格、库存、规格等),还涉及到商品的动态数据(如促销信息、库存状态等&#xf…...

面试题之Vuex,sessionStorage,localStorage的区别

Vuex、localStorage 和 sessionStorage 都是用于存储数据的技术,但它们在存储范围、存储方式、应用场景等方面存在显著区别。以下是它们的详细对比: 1. 存储范围 Vuex: 是 Vue.js 的状态管理库,用于存储全局状态。 数据存储在内…...

ssm121基于ssm的开放式教学评价管理系统+vue(源码+包运行+LW+技术指导)

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下,你想解决的问…...

【深度学习】Transformer入门:通俗易懂的介绍

【深度学习】Transformer入门:通俗易懂的介绍 一、引言二、从前的“读句子”方式三、Transformer的“超级阅读能力”四、Transformer是怎么做到的?五、Transformer的“多视角”能力六、Transformer的“位置记忆”七、Transformer的“翻译流程”八、Trans…...

大语言模型内容安全的方式有哪些

大语言模型内容安全的方式有哪些 LLM(大语言模型)内容安全方式主要是通过技术手段对模型生成的内容进行检测、过滤和干预,以确保输出符合道德、法律和社会规范。以下是一些常见的方式方法及其原理和著名的应用案例: 基于规则的过滤 原理:制定一系列明确的规则和模式,例…...

《深度学习》——ResNet网络

文章目录 ResNet网络ResNet网络实例导入所需库下载训练数据和测试数据设置每个批次的样本个数判断是否使用GPU定义残差模块定义ResNet网络模型导入GPU定义训练函数定义测试函数创建损失函数和优化器训练测试数据结果 ResNet网络 ResNet(Residual Network&#xff0…...

【Windows软件 - HeidiSQL】导出数据库

HeidSQL导出数据库 软件信息 具体操作 示例文件 选项分析 选项(1) 结果(1) -- -------------------------------------------------------- -- 主机: 127.0.0.1 -- 服务器版本: …...

FFmpeg 全面知识大纲梳理

1. FFmpeg 简介 FFmpeg 是什么: 一个开源的多媒体处理框架,用于处理音频、视频和流媒体。支持多种格式和编解码器。提供命令行工具和库(如 libavcodec, libavformat, libavfilter 等)。主要功能: 格式转换编解码流媒体处理音视频剪辑、合并、分离添加滤镜、特效压缩与优化…...

【达梦数据库】dblink连接[SqlServer/Mysql]报错处理

目录 背景问题1:无法测试以ODBC数据源方式访问的外部链接!问题分析&原因解决方法 问题2:DBLINK连接丢失问题分析&原因解决方法 问题3:DBIINK远程服务器获取对象[xxx]失败,错误洋情[[FreeTDS][SQL Server]Could not find stored proce…...

基于 Spring Boot 的社区居民健康管理系统部署说明书

目录 1 系统概述 2 准备资料 3 系统安装与部署 3.1 数据库部署 3.1.1 MySQL 的部署 3.1.2 Navicat 的部署 3.2 服务器部署 3.3 客户端部署 4 系统配置与优化 5 其他 基于 Spring Boot 的社区居民健康管理系统部署说明书 1 系统概述 本系统主要运用了 Spri…...

量化噪声介绍

量化噪声是在将模拟信号转换为数字信号的量化过程中产生的噪声。以下为你详细介绍: 1. 量化的基本概念 在模拟信号数字化过程中,采样是对模拟信号在时间上进行离散化,而量化则是对采样值在幅度上进行离散化。由于模拟信号的取值是连续的&am…...

java断点调试(debug)

在开发中,新手程序员在查找错误时, 这时老程序员就会温馨提示,可以用断点调试,一步一步的看源码执行的过程,从而发现错误所在。 重要提示: 断点调试过程是运行状态,是以对象的运行类型来执行的 断点调试介绍 断点调试是…...

最新智能优化算法:牛优化( Ox Optimizer,OX)算法求解经典23个函数测试集,MATLAB代码

一、牛优化算法 牛优化( OX Optimizer,OX)算法由 AhmadK.AlHwaitat 与 andHussamN.Fakhouri于2024年提出,该算法的设计灵感来源于公牛的行为特性。公牛以其巨大的力量而闻名,能够承载沉重的负担并进行远距离运输。这种…...

Redis7——基础篇(四)

前言:此篇文章系本人学习过程中记录下来的笔记,里面难免会有不少欠缺的地方,诚心期待大家多多给予指教。 基础篇: Redis(一)Redis(二)Redis(三) 接上期内容&…...

Git备忘录(三)

设置用户信息: git config --global user.name “itcast” git config --global user.email “ helloitcast.cn” 查看配置信息 git config --global user.name git config --global user.email $ git init $ git remote add origin gitgitee.com:XXX/avas.git $ git pull or…...

MySQL 之INDEX 索引(Index Index of MySQL)

MySQL 之INDEX 索引 1.4 INDEX 索引 1.4.1 索引介绍 索引:是排序的快速查找的特殊数据结构,定义作为查找条件的字段上,又称为键 key,索引通过存储引擎实现。 优点 大大加快数据的检索速度; 创建唯一性索引,保证数…...

Linux基础24-C语言之分支结构Ⅰ【入门级】

分支结构 问题抛出 我们在程序设计中往往会遇到如下问题,比如下面的函数计算: 也就是我们必须要通过一个条件的结果来选择下一步的操作,算法上属于一个分支结构,处于严重实现分支结构主要使用if语句。 条件判断 根据某个条件成…...

LeetCode47

LeetCode47 目录 题目描述示例思路分析代码段代码逐行讲解复杂度分析总结的知识点整合总结 题目描述 给定一个可包含重复数字的整数数组 nums,按任意顺序返回所有不重复的全排列。 示例 示例 1 输入: nums [1, 1, 2]输出: [[1, 1, 2],[1, 2, 1],[2, 1, 1] ]…...

C++中std::condition_variable_any、std::lock_guard 和 std::unique_

1、背景 在 C 多线程编程中,同步 和 互斥 是至关重要的概念。C 标准库提供了多种同步机制,其中 std::condition_variable_any、std::lock_guard 和 std::unique_lock 是经常被用到的工具。本文将详细介绍这三者的用途、区别、适用场景,并通过…...

详解AbstractQueuedSynchronizer(AQS)源码

引言 上篇文章讲解了CountDownLatch源码,底层是继承了AQS基类调用父类和重写父类方法实现的,本文将简介AQS源码和架构设计,帮助我们更深入理解多线程实战。 源码架构 1. 状态变量 state AQS 使用一个 int 类型的变量 state 来表示同步状态…...

【Unity动画】导入动画资源到项目中,Animator播放角色动画片段,角色会跟随着动画播放移动。

导入动画资源到项目中,Animator播放角色动画片段,角色会跟随着动画播放移动,但我只想要角色在原地播放动画。比如:播放一个角色Run动画,希望角色在原地奔跑,而不是产生了移动距离。 问题排查: 1.是否勾选…...

图解循环神经网络(RNN)

目录 1.循环神经网络介绍 2.网络结构 3.结构分类 4.模型工作原理 5.模型工作示例 6.总结 1.循环神经网络介绍 RNN(Recurrent Neural Network,循环神经网络)是一种专门用于处理序列数据的神经网络结构。与传统的神经网络不同&#xff0c…...

【数据结构】(9) 优先级队列(堆)

一、优先级队列 优先级队列不同于队列,队列是先进先出,优先级队列是优先级最高的先出。一般有两种操作:返回最高优先级对象,添加一个新对象。 二、堆 2.1、什么是堆 堆也是一种数据结构,是一棵完全二叉树&#xff0c…...

4、IP查找工具-Angry IP Scanner

在前序文章中,提到了多种IP查找方法,可能回存在不同场景需要使用不同的查找命令,有些不容易记忆,本文将介绍一个比较优秀的IP查找工具,可以应用在连接树莓派或查找IP的其他场景中。供大家参考。 Angry IP Scanner下载…...

【Linux】命令操作、打jar包、项目部署

阿华代码,不是逆风,就是我疯 你们的点赞收藏是我前进最大的动力!! 希望本文内容能够帮助到你!! 目录 一:Xshell下载 1:镜像设置 二:阿里云设置镜像Ubuntu 三&#xf…...

瑞萨RA-T系列芯片ADCGPT功能模块的配合使用

在马达或电源工程中,往往需要采集多路AD信号,且这些信号的优先级和采样时机不相同。本篇介绍在使用RA-T系列芯片建立马达或电源工程时,如何根据需求来设置主要功能模块ADC&GPT,包括采样通道打包和分组,GPT触发启动…...

python爬虫系列课程1:初识爬虫

python爬虫系列课程1:初识爬虫 一、爬虫的概念二、通用爬虫和自定义爬虫的区别三、开发语言四、爬虫流程一、爬虫的概念 网络爬虫(又被称为网页蜘蛛、网络机器人)就是模拟浏览器发送网络请求,接收请求响应,一种按照一定的规则,自动抓取互联网信息的程序。原则上,只要是…...