当前位置: 首页 > article >正文

2025-05-31 Python深度学习10——模型训练流程

文章目录

  • 1 数据准备
    • 1.1 下载与预处理
    • 1.2 数据加载
  • 2 模型构建
    • 2.1 自定义 CNN 模型
    • 2.2 GPU加速
  • 3 训练配置
    • 3.1 损失函数
    • 3.2 优化器
    • 3.3 训练参数
  • 4 训练循环
    • 4.1 训练模式 (`model.train()`)
    • 4.2 评估模式 (`model.eval()`)
  • 5 模型验证

本文环境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

​ 本文以 CIFAR-10 为例,介绍模型的大致训练流程。相关的 Python 包如下:

import torch
import torchvision
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time

1 数据准备

1.1 下载与预处理

​ 使用torchvision.datasets.CIFAR10下载 CIFAR-10 数据集(32x32 彩色图像,10 类),分为训练集(train=True,5 万张)和测试集(train=False,1 万张)。

# 准备数据集
train_data = torchvision.datasets.CIFAR10(root='./dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True
)test_data = torchvision.datasets.CIFAR10(root='./dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True
)
  • transform=torchvision.transforms.ToTensor():将图像转为 PyTorch 张量(Tensor),并自动归一化到 [0, 1] 范围。
  • download=True:若本地无数据,自动下载。

1.2 数据加载

​ 通过DataLoader分批次加载数据:

# 加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
  • batch_size=64:每批次处理 64 张图片,平衡内存占用和训练效率。
  • 训练集默认不打乱(未设置shuffle),测试集可添加shuffle=True以增强评估可靠性。

2 模型构建

2.1 自定义 CNN 模型

MyModel是一个3层卷积神经网络(CNN):

  1. 卷积层:nn.Conv2d(3, 32, 5, 1, 2)

    输入通道 3(RGB),输出通道 32,5×5 卷积核,步长 1,填充 2(保持尺寸不变)。

  2. 池化层:nn.MaxPool2d(2)

    2×2最大池化,尺寸减半。

  3. 全连接层

    • nn.Linear(64 * 4 * 4, 64)

      将展平后的特征(64 通道×4×4尺寸)映射到 64 维。

    • nn.Linear(64, 10)

      最终输出 10 类。

image-20250527161255093
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):return self.model(x)

2.2 GPU加速

​ 通过.to(device)将模型和数据移至 GPU(若可用),显著加速计算。

# 定义训练的设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = MyModel().to(device)  # 使用GPU

3 训练配置

3.1 损失函数

​ 使用nn.CrossEntropyLoss(),适用于多分类任务,计算预测概率与真实标签的交叉熵。

# 损失函数
loss_fn = nn.CrossEntropyLoss().to(device)  # 使用GPU

3.2 优化器

​ 使用torch.optim.SGD,随机梯度下降,学习率lr=1e-2,控制参数更新步长。

# 损失函数
loss_fn = nn.CrossEntropyLoss().to(device)  # 使用GPU

3.3 训练参数

  1. total_train_step:记录训练次数,用于日志和调试。
  2. total_test_step:记录测试次数,用于日志和调试。
  3. epoch=20:遍历完整数据集 20 次。
# 设置训练网络的一些参数
total_train_step = 0  # 记录训练次数
total_test_step = 0  # 记录测试次数
epoch = 20  # 训练的轮数

4 训练循环

数据加载  →  模型初始化  →  训练循环  →  测试评估  →  保存模型↑          ↑                  ↓           ↓└───TensorBoard日志 ←────── 参数更新 ←── 梯度计算

4.1 训练模式 (model.train())

  1. 前向传播:输入图像imgs,模型输出预测outputs
  2. 计算损失loss = loss_fn(outputs, targets),衡量预测误差。
  3. 反向传播
    • optimizer.zero_grad():清空梯度,避免累积。
    • loss.backward():计算梯度(链式法则)。
    • optimizer.step():更新模型参数。
  4. 日志记录:每 100 次训练记录损失和时间到 TensorBoard 中。
for i in range(epoch):print(f"------------第 {i + 1} 轮训练开始------------")# 训练步骤开始model.train()for data in train_dataloader:imgs, targets = dataimgs = imgs.to(device)  # 使用GPUtargets = targets.to(device)  # 使用GPUoutputs = model(imgs)loss = loss_fn(outputs, targets)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print(f"第 {total_train_step} 次训练,Loss:{loss.item()},Time:{end_time - start_time}")writer.add_scalar("train_loss", loss.item(), total_train_step)start_time = time.time()

4.2 评估模式 (model.eval())

  1. 关闭梯度计算with torch.no_grad(),节省内存并加速。
  2. 测试指标
    • 总损失:累加所有批次的损失total_test_loss
    • 准确率:统计预测正确的样本数(outputs.argmax(dim=1) == targets)。
  3. 日志记录:每轮测试后保存损失和准确率到 TensorBoard。
  4. 保存模型:通过torch.save()方法将模型的 state_dict 保存到本地文件中。
    # 测试步骤开始model.eval()total_test_loss = 0total_accuracy = 0accuracy_rate = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataimgs = imgs.to(device)  # 使用GPUtargets = targets.to(device)  # 使用GPUoutputs: Tensor = model(imgs)loss = loss_fn(outputs, targets)total_test_loss += lossaccuracy = outputs.argmax(dim=1) == targetstotal_accuracy += accuracy.sum()total_test_step += 1accuracy_rate = total_accuracy / test_data_sizeprint(f"第 {i + 1} 轮测试,Loss:{total_test_loss},Accuracy:{total_accuracy} ({accuracy_rate})")writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy, total_test_step)writer.add_scalar("accuracy_rate", accuracy_rate, total_test_step)torch.save(model.state_dict(), f"model/my_model.pth")  # 保存模型writer.close()
image-20250531123058128

说明

  1. 训练与评估模式切换

    • model.train():启用 Dropout 和 BatchNorm 的训练行为(如随机丢弃神经元)。

    • model.eval():固定 Dropout 和 BatchNorm 的统计量,确保评估一致性。

  2. GPU 数据迁移

    需将输入数据 imgs 和标签 targets 均移至 GPU,否则会报错。

  3. 梯度清零

    避免梯度累加导致参数更新错误。

完整代码

# train_gpu_2.pyimport torch
import torchvision
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time# 定义训练的设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 准备数据集
train_data = torchvision.datasets.CIFAR10(root='./dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True
)test_data = torchvision.datasets.CIFAR10(root='./dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True
)# 数据集大小
train_data_size = len(train_data)
test_data_size = len(test_data)print(f"训练集数量:{train_data_size}")
print(f"测试集数量:{test_data_size}")# 加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 创建网络模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):return self.model(x)model = MyModel().to(device)  # 使用GPU# 损失函数
loss_fn = nn.CrossEntropyLoss().to(device)  # 使用GPU# 优化器
lr = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=lr)# 设置训练网络的一些参数
total_train_step = 0  # 记录训练次数
total_test_step = 0  # 记录测试次数
epoch = 20  # 训练的轮数# 添加 tensorboard
writer = SummaryWriter("../logs_train")for i in range(epoch):print(f"------------第 {i + 1} 轮训练开始------------")start_time = time.time()# 训练步骤开始model.train()for data in train_dataloader:imgs, targets = dataimgs = imgs.to(device)  # 使用GPUtargets = targets.to(device)  # 使用GPUoutputs = model(imgs)loss = loss_fn(outputs, targets)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print(f"第 {total_train_step} 次训练,Loss:{loss.item()},Time:{end_time - start_time}")writer.add_scalar("train_loss", loss.item(), total_train_step)start_time = time.time()# 测试步骤开始model.eval()total_test_loss = 0total_accuracy = 0accuracy_rate = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataimgs = imgs.to(device)  # 使用GPUtargets = targets.to(device)  # 使用GPUoutputs: Tensor = model(imgs)loss = loss_fn(outputs, targets)total_test_loss += lossaccuracy = outputs.argmax(dim=1) == targetstotal_accuracy += accuracy.sum()total_test_step += 1accuracy_rate = total_accuracy / test_data_sizeprint(f"第 {i + 1} 轮测试,Loss:{total_test_loss},Accuracy:{total_accuracy} ({accuracy_rate})")writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy, total_test_step)writer.add_scalar("accuracy_rate", accuracy_rate, total_test_step)torch.save(model.state_dict(), f"model/my_model.pth")  # 保存模型writer.close()

5 模型验证

​ 准备待验证的图片,放在 imgae 目录下。

image-20250531123319819

​ 编写 test.py 文件,用于验证模型。

# test.pyimport torch
import torchvision
from PIL import Image
from torch import nn# 定义图片路径
image_path = "image/dog.png"# 打开图片并转换为RGB格式
image = Image.open(image_path).convert('RGB')# 定义图片转换操作
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),  # 将图片大小调整为32x32torchvision.transforms.ToTensor()  # 将图片转换为张量
])# 对图片进行转换操作
image = transform(image).reshape(1, 3, 32, 32)# 定义模型类
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 定义模型结构self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),  # 第一个卷积层,输入通道数为3,输出通道数为32,卷积核大小为5,步长为1,填充为2nn.MaxPool2d(2),  # 最大池化层,池化核大小为2nn.Conv2d(32, 32, 5, 1, 2),  # 第二个卷积层,输入通道数为32,输出通道数为32,卷积核大小为5,步长为1,填充为2nn.MaxPool2d(2),  # 最大池化层,池化核大小为2nn.Conv2d(32, 64, 5, 1, 2),  # 第三个卷积层,输入通道数为32,输出通道数为64,卷积核大小为5,步长为1,填充为2nn.MaxPool2d(2),  # 最大池化层,池化核大小为2nn.Flatten(),  # 展平操作nn.Linear(64 * 4 * 4, 64),  # 全连接层,输入维度为64*4*4,输出维度为64nn.Linear(64, 10)  # 全连接层,输入维度为64,输出维度为10)def forward(self, x):return self.model(x)# 加载模型参数
model_dict = torch.load('model/my_model.pth')
model = MyModel()
model.load_state_dict(model_dict)
model.to('cuda')# 设置模型为评估模式
model.eval()
# 关闭梯度计算
with torch.no_grad():# 将图片转换为GPU格式image = image.to('cuda')# 进行模型推理output = model(image)# 打印输出结果
print(output)
# 打印输出结果中最大值的索引
print(output.argmax(1))

​ 验证结果如下,表明 dog.png 图片的预测结果索引为 5,即第 6 类预测。

image-20250531123550636

​ 依据分类规则,预测结果为 dog,是正确的。

image-20250531123902927

相关文章:

2025-05-31 Python深度学习10——模型训练流程

文章目录 1 数据准备1.1 下载与预处理1.2 数据加载 2 模型构建2.1 自定义 CNN 模型2.2 GPU加速 3 训练配置3.1 损失函数3.2 优化器3.3 训练参数 4 训练循环4.1 训练模式 (model.train())4.2 评估模式 (model.eval()) 5 模型验证 本文环境: Pycharm 2025.1Python 3.1…...

卷积神经网络(CNN)、YOLO和人脸识别之间的关系

核心关系图解 TEXT 摄像头图像 → [YOLO:人脸检测] → 定位人脸位置 → [CNN:特征提取] → 人脸特征向量 → [人脸识别系统] → 身份匹配 通俗比喻 想象你在一个拥挤的火车站找人: YOLO 是你的"快速扫描眼": 一眼扫…...

K8S StatefulSet 快速开始

其实这篇文章的梗概已经写了很久了,中间我小孩出生了,从此人间多了一份牵挂。抽出一些时间去办理新生儿相关手续。初为人父确实艰辛,就像学技术一样,都需要有极大的耐心,付出很多的时间。 一、引子 1.1、独立的存储 …...

重新测试deepseek Jakarta EE 10编程能力

听说deepseek做了一个小更新,我重新测试了一下Jakarta EE 10编程能力;有点进步,遗漏的功能比以前少了。 采用Jakarta EE 10 编写员工信息表维护表,包括员工查询与搜索、员工列表、新增员工、删除员工,修改员工&#xf…...

nav2笔记-250603

合作背景: AMD与Open Navigation在过去几个月里进行了合作,旨在向ROS 2社区展示AMD强大的Ryzen AI、Embedded和Kria能力。 演示内容: 帖子提到,他们已经开始展示如何使用Ryzen AI为自主机器人产品提供动力,在各种现实世…...

指纹识别+精准化POC攻击

开发目的 解决漏洞扫描器的痛点 第一就是扫描量太大,对一个站点扫描了大量的无用 POC,浪费时间 指纹识别后还需要根据对应的指纹去进行 payload 扫描,非常的麻烦 开发思路 我们的思路分为大体分为指纹POC扫描 所以思路大概从这几个方面…...

LeetCode[404]左叶子之和

思路: 题目要求求出左叶子的和,左叶子的条件是左右节点为空且是左子树的叶子节点才叫左叶子节点,那么右子树的左叶子节点的和是什么呢?这样想就引出了递归的顺序,后序遍历,求出左右子树的节点和&#xff0c…...

mac环境下的python、pycharm和pip安装使用

Python安装 Mac环境下的python安装 下载地址:https://www.jetbrains.com.cn/pycharm/ 一直点击下一步即可完成 在应用程序中会多了两个图标 IDLE 和 Python launcher IDLE支持在窗口中直接敲python命令并立即执行,双击即可打开 Python launcher双击打…...

C语言多级指针深度解析:从一级到三级的奥秘

资料合集下载链接: ​​https://pan.quark.cn/s/472bbdfcd014​​ 在C语言中,指针是理解内存和进行底层编程的关键。我们知道,一个一级指针存储的是一个变量的内存地址。但C语言的强大之处在于,指针本身也可以有自己的地址,而存储这个指针的地址的变量,就是一个更高层次…...

uni-app学习笔记十九--pages.json全局样式globalStyle设置

pages.json 页面路由 pages.json 文件用来对 uni-app 进行全局配置,决定页面文件的路径、窗口样式、原生的导航栏、底部的原生tabbar 等。 导航栏高度为 44px (不含状态栏),tabBar 高度为 50px (不含安全区)。 它类似微信小程序中app.json的页面管理部…...

BUUCTF[极客大挑战 2019]Havefun 1题解

BUUCTF[极客大挑战 2019]Havefun 1题解 题目分析解题理解代码逻辑:构造Payload: 总结 题目分析 生成靶机,进入网址: 首页几乎没有任何信息,公式化F12打开源码,发现一段被注释的源码: 下面我们…...

【基础】Unity中Camera组件知识点

一、投影模式 (Projection) 1. 透视模式 (Perspective) 原理:模拟人眼,近大远小(锥形体视锥) 核心参数: Field of View (FOV):垂直视场角 典型值:第一人称 60-90,驾驶舱 30-45 特…...

Tomcat优化篇

目录 一、Tomcat自身配置 1.Tomcat管理页面 2. 禁用AJP服务 3.Executor优化 4.三种运行模式 5.web.xml 6.Host标签 7.Context标签 8.启动速度优化 9.其他方面 二、JMeter测试 笔者推荐 一、Tomcat自身配置 1.Tomcat管理页面 我们可以打开Tomcat的管理页面&#xff…...

Temporal Fusion Transformer(TFT)扩散模型时间序列预测模型

1. TFT 简介 Temporal Fusion Transformer(TFT)模型是一种专为时间序列预测设计的高级深度学习模型。它结合了神经网络的多种机制处理时间序列数据中的复杂关系。TFT 由 Lim et al. 于 2019年提出,旨在处理时间序列中的不确定性和多尺度的依…...

【LangServe部署流程】5 分钟部署你的 AI 服务

目录 一、LangServe简介 二、环境准备 1. 安装必要依赖 2. 编写一个 LangChain 可运行链(Runnable) 3. 启动 LangServe 服务 4. 启动服务 5. 使用 API 进行调用 三、可选:访问交互式 Swagger 文档 四、基于 LangServe 的 RAG 应用部…...

攻防世界-unseping

进入环境 在获得的场景中发现PHP代码并进行分析 编写PHP编码 得到 Tzo0OiJlYXNlIjoyOntzOjEyOiIAZWFzZQBtZXRob2QiO3M6NDoicGluZyI7czoxMDoiAGVhc2UAYXJncyI7YToxOntpOjA7czozOiJwd2QiO319 将其传入 想执行ls,但是发现被过滤掉了 使用环境变量进行绕过 $a new…...

微软推出 Bing Video Creator,免费助力用户轻松创作 AI 视频

2025 年 6 月 2 日,微软正式在自家 Bing 应用中上线了一项名为 “Bing Video Creator” 的新功能,为广大用户带来了全新的创作体验。 Bing Video Creator 背靠 OpenAI 当红的 Sora 视频生成模型,用户只需输入文字描述,就能直接生…...

(13)java+ selenium->元素定位大法之By_partial_link_text

1.简介 在上一篇中我们说了link_text,目前我们接着看partial link text,顾名思义是通过链接定位的(官方说法:超链接文本定位)。我们在上一篇的文章末尾有提到,这种方式的定位属于模糊匹配定位,什么是partial link text呢,看到part这个单词我们就可以知道,当这个文字超…...

Xget 正式发布:您的高性能、安全下载加速工具!

您可以通过 star 我固定的 GitHub 存储库来支持我,谢谢!以下是我的一些 GitHub 存储库,很有可能对您有用: tzst Xget Prompt Library 原文 URL:https://blog.xi-xu.me/2025/06/02/xget-launch-high-performance-sec…...

[yolov11改进系列]基于yolov11使用FasterNet替换backbone用于轻量化网络的python源码+训练源码

【FasterNet介绍】 为了设计快速神经网络,许多工作都集中在减少浮点运算的数量(FLOPs)上。 然而,我们观察到FLOPs的减少并不一定会导致延迟的类似程度的减少。 这主要源于低效率的每秒浮点运算(FLOPS)。 为了实现更快的网络&#…...

一周学会Pandas2之Python数据处理与分析-Pandas2数据绘图与可视化

锋哥原创的Pandas2 Python数据处理与分析 视频教程: 2025版 Pandas2 Python数据处理与分析 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili Pandas 集成了 Matplotlib,提供了简单高效的绘图接口,使数据可视化变得直观便捷。本指南将详…...

企业级安全实践:SSL/TLS 加密与权限管理(一)

引言 ** 在数字化转型的浪潮中,企业对网络的依赖程度与日俱增,从日常办公到核心业务的开展,都离不开网络的支持。与此同时,网络安全问题也日益严峻,成为企业发展过程中不可忽视的重要挑战。 一旦企业遭遇网络安全事…...

2025——》VSCode Windows 最新安装指南/VSCode安装完成后如何验证是否成功?2025最新VSCode安装配置全攻略

1.VSCode Windows 最新安装指南: 以下是 2025 年 Windows 系统下安装 Visual Studio Code(VSCode)的最新指南,结合官方文档与实际操作经验整理而成: 一、下载官方安装包: 1.访问官网: 打开浏览器,进入 VSCode 官方下载页面https://code.visualstudio.com/Download 2…...

RabbitMQ如何保证消息可靠性

RabbitMQ是一个流行的开源消息代理,它提供了可靠的消息传递机制,广泛应用于分布式系统和微服务架构中。在现代应用中,确保消息的可靠性至关重要,以防止消息丢失和重复处理。本文将详细探讨RabbitMQ如何通过多种机制保证消息的可靠…...

【MATLAB代码】制导——三点法,二维平面下的例程|运动目标制导,附完整源代码

三点法制导是一种导弹制导策略,主要用于确保导弹能够准确追踪并击中移动目标。该方法通过计算导弹、目标和制导站之间的相对位置关系,实现对目标的有效制导。 本文给出MATLAB下的三点法例程,模拟平面上捕获运动目标的情况订阅专栏后可直接查看源代码,粘贴到MATLAB空脚本中即…...

Spring Security用户管理机制详解

UserDetailsService契约解析 核心方法解析 UserDetailsService接口仅定义了一个关键方法loadUserByUsername(),其方法签名如下: public interface UserDetailsService {UserDetails loadUserByUsername(String username) throws UsernameNotFoundException; }该方法作为用…...

如何爬取google应用商店的应用分类呢?

以下是爬取Google Play商店应用包名(package name)和对应分类的完整解决方案,采用ScrapyPlaywright组合应对动态渲染页面,并处理反爬机制: 完整爬虫实现 1. 安装必要库 # 卸载现有安装pip uninstall playwright scrapy-playwright -y# 重新…...

SQL Relational Algebra(数据库关系代数)

目录 What is an “Algebra” What is Relational Algebra? Core Relational Algebra Selection Projection Extended Projection Product(笛卡尔积) Theta-Join Natural Join Renaming Building Complex Expressions Sequences of Assignm…...

如何安装huaweicloud-sdk-core-3.1.142.jar到本地仓库?

如何安装huaweicloud-sdk-core-3.1.142.jar到本地仓库? package com.huaweicloud.sdk.core.auth does not exist 解决方案 # 下载huaweicloud-sdk-core-3.1.142.jar wget https://repo1.maven.org/maven2/com/huaweicloud/sdk/huaweicloud-sdk-core/3.1.142/huawe…...

Electron桌面应用下,在拍照、展示pdf等模块时,容易导致应用白屏

Electron 应用白屏问题分析与解决方案 Electron 应用中拍照、PDF展示等模块导致白屏的常见原因通常与内存泄漏、渲染进程崩溃或资源加载超时有关。以下是具体排查与解决方法: 检查内存泄漏 项目中,分析代码,高频操作或未释放的资源可能导致…...