一个完整的神经网络训练流程详解(附 PyTorch 示例)
🧠 一个完整的神经网络训练流程详解(附 PyTorch 示例)
📌 第一部分:神经网络训练流程概览(总)
在深度学习中,构建和训练一个神经网络模型并不是简单的“输入数据、得到结果”这么简单。整个过程是一个系统化、模块化的工程,涵盖了从原始数据到最终模型部署的完整生命周期。
以下是一个完整的神经网络训练流程概览表,帮助你快速理解每个环节的作用和相互关系:
步骤编号 | 流程名称 | 关键操作 | 目标/作用 |
---|---|---|---|
1 | 数据准备 | 加载、清洗、标准化、划分训练集/验证集/测试集 | 为模型提供结构化、干净的输入数据 |
2 | 模型定义 | 设计网络结构,选择激活函数、初始化参数 | 构建具备预测能力的模型框架 |
3 | 损失函数选择 | 定义目标函数(如交叉熵、均方误差) | 衡量模型预测与真实值之间的差距 |
4 | 优化器设置 | 选择优化算法(如 Adam、SGD)、配置学习率等参数 | 决定如何利用梯度更新模型参数 |
5 | 训练循环 | 正向传播 → 反向传播 → 参数更新 | 模型学习的核心机制 |
6 | 验证与调参 | 在验证集上评估性能,调整超参数 | 防止过拟合,提高泛化能力 |
7 | 测试与评估 | 在测试集上评估最终性能 | 客观评价模型在未知数据上的表现 |
8 | 模型保存与部署 | 保存模型参数、转换格式、部署上线 | 将模型应用于实际场景 |
关于第5部分的内容,可以看我的另一篇文章:如何理解神经网络训练的循环过程
✅ 一句话总结第一部分:
神经网络训练是一个端到端的过程,包括从数据预处理到模型部署的八大核心步骤。
🧩 第二部分:详细讲解每一步流程(分)
我们接下来以一个具体的图像分类任务为例(如 MNIST 手写数字识别),用 PyTorch 来实现每一个步骤。
1️⃣ 数据准备
⭐ 功能说明:
- 加载并预处理数据
- 划分训练集与测试集
- 构造
DataLoader
以便批量读取数据
✅ 代码示例(PyTorch):
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 数据预处理:将图像转为张量,并做归一化
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 构建 DataLoader
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
2️⃣ 模型定义
⭐ 功能说明:
- 定义网络结构(这里使用一个简单的全连接网络)
- 初始化参数(一般自动完成)
✅ 代码示例(PyTorch):
import torch.nn as nnclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28*28) # 展平图像x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return xmodel = SimpleNet()
3️⃣ 损失函数选择
⭐ 功能说明:
- 分类任务常用交叉熵损失函数
✅ 代码示例:
criterion = nn.CrossEntropyLoss()
4️⃣ 优化器设置
⭐ 功能说明:
- 使用 Adam 优化器进行参数更新
✅ 代码示例:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
5️⃣ 训练循环
⭐ 功能说明:
- 实现完整的训练迭代流程:
- 正向传播
- 损失计算
- 反向传播
- 参数更新
✅ 代码示例:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)num_epochs = 5for epoch in range(num_epochs):model.train()running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 正向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播 + 参数更新optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')
6️⃣ 验证与调参(可选)
⭐ 功能说明:
- 监控验证集损失或准确率
- 防止过拟合,提前停止训练
✅ 代码片段(验证阶段):
def evaluate(model, data_loader):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in data_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()return 100 * correct / totalval_acc = evaluate(model, test_loader)
print(f'Validation Accuracy: {val_acc:.2f}%')
7️⃣ 测试与评估
⭐ 功能说明:
- 最终在测试集上评估模型性能
✅ 代码复用上面的 evaluate()
即可
8️⃣ 模型保存与部署
⭐ 功能说明:
- 保存模型用于后续推理或上线使用
✅ 代码示例:
# 保存模型参数
torch.save(model.state_dict(), 'mnist_model.pth')# 加载模型参数
model.load_state_dict(torch.load('mnist_model.pth'))
🎯 第三部分:总结整个流程(总)
一个完整的神经网络训练流程是一个系统性、模块化的过程,主要包括以下八个关键步骤:
- 数据准备:清洗、标准化、构建 DataLoader
- 模型定义:设计合适的网络结构
- 损失函数选择:衡量预测误差
- 优化器设置:决定参数更新方式
- 训练循环执行:正向传播 → 反向传播 → 参数更新
- 验证与调参:防止过拟合,调整超参数
- 测试与评估:对模型性能进行最终评估
- 模型保存与部署:将模型落地应用
通过这一系列流程,我们可以从零开始训练出一个具备实用价值的神经网络模型,并将其应用于现实问题中。
💡 补充建议(可根据需要扩展)
- 增加可视化部分(如 TensorBoard 或 matplotlib 绘图)
- 添加早停(Early Stopping)机制
- 使用更复杂的网络(CNN、Transformer 等)
- 多 GPU 支持(DDP、DataParallel)
- 使用混合精度训练(AMP)
- 介绍模型压缩与量化(便于部署)
相关文章:
一个完整的神经网络训练流程详解(附 PyTorch 示例)
🧠 一个完整的神经网络训练流程详解(附 PyTorch 示例) 📌 第一部分:神经网络训练流程概览(总) 在深度学习中,构建和训练一个神经网络模型并不是简单的“输入数据、得到结果”这么简…...

OpenCV 图形API(77)图像与通道拼接函数-----对图像进行几何变换函数remap()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 对图像应用一个通用的几何变换。 函数 remap 使用指定的映射对源图像进行变换: dst ( x , y ) src ( m a p x ( x , y ) , m a p y…...
windows通过wsl安装ubuntu20.04
1 *.bat文件安装hyper-v pushd "%~dp0" dir /b %SystemRoot%\servicing\Packages\*Hyper-V*.mum >hyper-v.txt for /f %%i in (findstr /i . hyper-v.txt 2^>nul) do dism /online /norestart /add-package:"%SystemRoot%\servicing\Packages\%%i"…...

Spring AI 入门(持续更新)
介绍 Spring AI 是 Spring 项目中一个面向 AI 应用的模块,旨在通过集成开源框架、提供标准化的工具和便捷的开发体验,加速 AI 应用程序的构建和部署。 依赖 <!-- 基于 WebFlux 的响应式 SSE 传输 --> <dependency><groupId>org.spr…...

QUIC协议优化:HTTP_3环境下的超高速异步抓取方案
摘要 随着 QUIC 和 HTTP/3 的普及,基于 UDP 的连接复用与内置加密带来了远超 HTTP/2 的性能提升,可显著降低连接握手与拥塞恢复的开销。本文以爬取知乎热榜数据为目标,提出一种基于 HTTPX aioquic 的异步抓取方案,并结合代理 IP设…...

uni-app实现完成任务解锁拼图功能
界面如下 代码如下 <template><view class"puzzle-container"><view class"puzzle-title">任务进度 {{completedCount}}/{{totalPieces}}</view><view class"puzzle-grid"><viewv-for"(piece, index) in…...
Vue3 中当组件嵌套层级较深导致 ref 无法直接获取子组件实例时,可以通过 provide/inject + 回调函数的方式实现子组件方法传递到父组件
需求:vue3中使用defineExposeref调用子组件方法报错不是一个function 思路:由于组件嵌套层级太深导致ref失效,通过provide/inject 回调函数来实现多层穿透 1. 父组件提供「方法注册函数」 父组件通过 provide 提供一个用于接收子组件方法…...
关于 js:3. 闭包、作用域、内存模型
一、闭包的本质:函数 其词法作用域环境 闭包(Closure)的本质可以概括为: 闭包是一个函数,以及它定义时捕获的词法作用域中的变量集合。 这意味着:即使外部函数已经返回或作用域结束,只要有内…...

数据链路层(MAC 地址)
目录 一、前言: 二、以太网: 三、MAC 地址的作用: 四、ARP协议: 一、前言: 数据链路层主要负责相邻两个节点之间的数据传输,其中,最常见数据链路层的协议有 以太网(通过光纤 / 网…...

基于DQN的自动驾驶小车绕圈任务
1.任务介绍 任务来源: DQN: Deep Q Learning |自动驾驶入门(?) |算法与实现 任务原始代码: self-driving car 最终效果: 以下所有内容,都是对上面DQN代码的改进&#…...
terraform resource创建了5台阿里云ecs,如要使用terraform删除其中一台主机,如何删除?
在 Terraform 中删除阿里云 5 台 ECS 实例中的某一台,具体操作取决于你创建资源时使用的 多实例管理方式(count 或 for_each)。以下是详细解决方案: 方法一:使用 for_each(推荐) 如果创建时使…...

【Linux】Linux工具(1)
3.Linux工具(1) 文章目录 3.Linux工具(1)Linux 软件包管理器 yum什么是软件包关于 rzsz查看软件包——yum list命令如何安装软件如何卸载软件补充——yum如何找到要安装软件的下载地址 Linux开发工具Linux编辑器-vim使用1.vim的基…...
探索大语言模型(LLM):词袋法(Bag of Words)原理与实现
文章目录 引言一、词袋法原理1.1 核心思想1.2 实现步骤 二、数学公式2.1 词频表示2.2 TF-IDF加权(可选) 三、示例表格3.1 构建词汇表3.2 文本向量化(词频) 四、Python代码实现4.1 基础实现(手动计算)4.2 输…...
vue引入物理引擎matter.js
vue引入物理引擎matter.js 在 Vue 项目中集成 Matter.js 物理引擎的步骤如下: 1. 安装 Matter.js npm install matter-js # 或 yarn add matter-js2. 创建 Vue 组件 <template><div ref="physicsContainer" class="physics-container"><…...

基于 Spring Boot 瑞吉外卖系统开发(十一)
基于 Spring Boot 瑞吉外卖系统开发(十一) 菜品启售和停售 “批量启售”、“批量停售”、操作列的售卖状态绑定单击事件,触发单击事件时,最终携带需要修改售卖状态的菜品id以post请求方式向“/dish/status/{params.status}”发送…...
支持鸿蒙next的uts插件
*本文共四个功能函数,相当于四个插件。作者为了偷懒写成了一个插件,调对应的函数即可。 1、chooseImageHarmony函数:拉起相册选择图片并转为Base64 2、takePhotoAndConvertToBase64函数:拉起相机拍照并转为Base64 3、openBrows…...

深入理解负载均衡:传输层与应用层的原理与实战
目录 前言1. 传输层(Layer 4)负载均衡1.1 工作层级与核心机制1.2 实现方式详解1.3 优缺点分析1.4 典型实现工具 2. 应用层(Layer 7)负载均衡2.1 工作层级与核心机制2.2 实现方式解析2.3 优缺点分析2.4 常用实现工具 3. Layer 4 与…...

WPF之Slider控件详解
文章目录 1. 概述2. 基本属性2.1 值范围属性2.2 滑动步长属性2.3 刻度显示属性2.4 方向属性2.5 选择范围属性 3. 事件处理3.1 值变化事件3.2 滑块拖动事件 4. 样式和模板自定义4.1 基本样式设置4.2 控件模板自定义 5. 数据绑定5.1 绑定到ViewModel5.2 同步多个控件 6. 实际应用…...
极狐GitLab 如何将项目共享给群组?
极狐GitLab 是 GitLab 在中国的发行版,关于中文参考文档和资料有: 极狐GitLab 中文文档极狐GitLab 中文论坛极狐GitLab 官网 共享项目和群组 (BASIC ALL) 在极狐GitLab 16.10 中,更改为在成员页面的成员选项卡上显示被邀请群组成员…...

企业微信自建消息推送应用
企业微信自建应用来推送消息 前言 最近有个给特定部门推送消息的需求,所以配置一个应用专门用来推送消息。实现过程大致为:服务器生成每天的报告,通过调用API来发送消息。以前一直都是发邮件,整个邮箱里全是报告文件,…...
【React】Hooks useReducer 详解,让状态管理更可预测、更高效
1.背景 useReducer是React提供的一个高级Hook,没有它我们也可以正常开发,但是useReducer可以使我们的代码具有更好的可读性,可维护性。 useReducer 跟 useState 一样的都是帮我们管理组件的状态的,但是呢与useState不同的是 useReducer 是集…...

日志之ClickHouse部署及替换ELK中的Elasticsearch
文章目录 1 ELK替换1.1 Elasticsearch vs ClickHouse1.2 环境部署1.2.1 zookeeper 集群部署1.2.2 Kafka 集群部署1.2.3 FileBeat 部署1.2.4 clickhouse 部署1.2.4.1 准备步骤1.2.4.2 添加官方存储库1.2.4.3 部署&启动&连接1.2.4.5 基本配置服务1.2.4.6 测试创建数据库和…...
亚远景-ASPICE vs ISO 21434:汽车软件开发标准的深度对比
ASPICE(Automotive SPICE)和ISO 21434是汽车软件开发领域的两大核心标准,分别聚焦于过程质量与网络安全。以下从核心目标、覆盖范围、实施重点、协同关系及行业价值五个维度进行深度对比分析: 一、核心目标对比 ASPICE࿱…...
51单片机快速成长路径
作为在嵌入式领域深耕18年的工程师,分享一条经过工业验证的51单片机快速成长路径,全程干货无注水: 一、突破认知误区(新手必看) 不要纠结于「汇编还是C」:现代开发90%场景用C,掌握指针和内存管…...
使用 NGINX 实现 HTTP Basic 认证ngx_http_auth_basic_module 模块
一、前言 在 Web 应用中,对部分资源进行访问控制是十分常见的需求。除了基于 IP 限制、JWT 验证、子请求校验等方式外,最经典也最简单的一种方式便是 HTTP Basic Authentication。NGINX 提供的 ngx_http_auth_basic_module 模块支持基于用户名和密码的基…...

解构与重构:自动化测试框架的进阶认知之旅
目录 一、自动化测试的介绍 (一)自动化测试的起源与发展 (二)自动化测试的定义与目标 (三)自动化测试的适用场景 二、什么是自动化测试框架 (一)自动化测试框架的定义 &#x…...

DockerDesktop替换方案
背景 由于DockerDesktop并非开源软件,如果在公司使用,可能就有一些限制,那是不是除了使用DockerDesktop外,就没其它办法了呢,现在咱们来说说替换方案。 WSL WSL是什么,可自行百度,这里引用WS…...

力扣热题100之搜索二维矩阵 II
题目 编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性: 每行的元素从左到右升序排列。 每列的元素从上到下升序排列。 代码 方法一:直接全体遍历 这个方法很直接,但是居然没有超时,…...

docker操作镜像-以mysql为例
Docker安装使用-CSDN博客 docker操作镜像-以mysql为例 当安装一个新的镜像时可以登录https://hub.docker.com/直接搜索想要安装的镜像,查看文档 1)拉取镜像 docker pull mysql 或者 docker pull mysql:版本号 然后直接跳到第4)步即可 2…...

使用OpenCV 和 Dlib 进行卷积神经网络人脸检测
文章目录 引言1.准备工作2.代码解析2.1 导入必要的库2.2 加载CNN人脸检测模型2.3 加载并预处理图像2.4 进行人脸检测2.5 绘制检测结果2.6 显示结果 3.完整代码4.性能考虑5.总结 引言 人脸检测是计算机视觉中最基础也最重要的任务之一。今天我将分享如何使用dlib库中的CNN人脸检…...