当前位置: 首页 > article >正文

DAY33MLP神经网络的训练

一、 核心知识点回顾1. 环境配置基础核心操作PyTorch 与 CUDA 的安装、验证及环境排查。关键命令查看显卡信息nvidia-smiCMD 中使用。CUDA 检查验证 PyTorch 是否能调用 GPU 加速.cuda()。2. MLP 训练全流程PyTorch 标准五步法步骤核心内容关键要点a. 数据预处理归一化、转换为张量Tensor将数据转化为模型可计算的数值格式为训练做准备。b. 模型定义继承nn.Module类1. 构建网络层如全连接层nn.Linear。2. 编写forward前向传播逻辑。c. 损失与优化定义损失函数、优化器分类任务常用交叉熵损失回归任务常用 MSE 损失优化器如 SGD、Adam。d. 训练流程迭代训练前向传播 ➔ 计算损失 ➔ 反向传播 ➔ 参数更新。e. 可视化绘制 Loss 曲线监控训练过程判断模型是否收敛或过拟合。二、 关键注意事项避坑指南这是实际编码中极易出错的细节必须严格遵守数据类型规范分类任务标签Label必须转换为LongTensor类型。原因交叉熵损失函数CrossEntropyLoss要求目标标签为整数索引若传入 Float 类型会报错。回归任务标签Label必须转换为FloatTensor类型通常为torch.float32。原因回归预测的是连续数值需保持与输出数据类型一致。三、 模型设计思路当前设定图片中提到选择了2 层隐藏层且固定神经元数量。类比理解这类似于传统机器学习中指定超参数是一种基础的网络结构设定。未来方向文中提到 “调参我们未来再提”暗示后续会涉及神经元数量调整、网络层数加深等更复杂的模型优化内容。总结该内容处于深度学习实战入门阶段重点在于规范搭建 MLP 训练流程。核心在于环境打通与数据类型匹配只要保证标签类型正确就能顺利完成第一次神经网络的训练与可视化。import torch import torch.nn as nn import torch.optim as optim import numpy as np import matplotlib.pyplot as plt # 1. 环境配置与数据准备 # 检查CUDA是否可用自动选择设备GPU/CPU device torch.device(cuda if torch.cuda.is_available() else cpu) print(f使用设备: {device}) # 生成模拟数据集分类任务 # 输入特征维度10类别数5样本数1000 input_dim 10 num_classes 5 sample_num 1000 # 生成随机特征float32类型 x torch.randn(sample_num, input_dim, dtypetorch.float32).to(device) # 分类任务标签必须是LongTensor整数类型 y_classification torch.randint(0, num_classes, (sample_num,), dtypetorch.long).to(device) # 回归任务标签必须是FloatTensor浮点类型 y_regression torch.randn(sample_num, 1, dtypetorch.float32).to(device) # 2. 定义MLP模型2层隐藏层 class MLP(nn.Module): def __init__(self, input_dim, hidden_dim164, hidden_dim232, output_dim1, task_typeclassification): super(MLP, self).__init__() # 2层隐藏层固定神经元数量后续可调参 self.fc1 nn.Linear(input_dim, hidden_dim1) # 第一层隐藏层 self.fc2 nn.Linear(hidden_dim1, hidden_dim2) # 第二层隐藏层 self.fc3 nn.Linear(hidden_dim2, output_dim) # 输出层 self.relu nn.ReLU() # 激活函数 self.task_type task_type def forward(self, x): # 前向传播 out self.relu(self.fc1(x)) out self.relu(self.fc2(out)) out self.fc3(out) # 分类任务输出层不加激活CrossEntropyLoss内置Softmax if self.task_type classification and output_dim 1: return out # 回归任务直接输出连续值 return out # 3. 初始化模型、损失函数、优化器 # 分类任务配置 model_class MLP(input_diminput_dim, output_dimnum_classes, task_typeclassification).to(device) criterion_class nn.CrossEntropyLoss() # 分类损失 optimizer_class optim.Adam(model_class.parameters(), lr0.001) # 优化器 # 回归任务配置可选 # model_reg MLP(input_diminput_dim, output_dim1, task_typeregression).to(device) # criterion_reg nn.MSELoss() # 回归损失 # optimizer_reg optim.Adam(model_reg.parameters(), lr0.001) # 4. 训练流程 epochs 100 # 训练轮数 loss_history [] # 记录损失变化 model_class.train() # 切换到训练模式 for epoch in range(epochs): # 前向传播 outputs model_class(x) loss criterion_class(outputs, y_classification) # 反向传播 参数更新 optimizer_class.zero_grad() # 清空梯度 loss.backward() # 反向传播 optimizer_class.step() # 更新参数 # 记录损失 loss_history.append(loss.item()) # 每10轮打印一次 if (epoch 1) % 10 0: print(fEpoch [{epoch1}/{epochs}], Loss: {loss.item():.4f}) # 5. 可视化Loss曲线 plt.plot(loss_history) plt.xlabel(Epoch) plt.ylabel(Loss) plt.title(MLP Training Loss Curve) plt.grid(True) plt.show() # 6. 简单验证 model_class.eval() # 切换到评估模式 with torch.no_grad(): # 禁用梯度计算加速 test_x torch.randn(10, input_dim, dtypetorch.float32).to(device) pred model_class(test_x) pred_label torch.argmax(pred, dim1) # 取概率最大的类别 print(\n测试样本预测结果类别索引, pred_label.cpu().numpy())代码关键部分解释环境配置自动检测 CUDA优先使用 GPU 加速对应图片中nvidia-smi和 CUDA 验证若没有 GPU自动降级到 CPU 运行不影响核心功能。数据类型严格匹配分类任务标签y_classification用torch.longLongTensor解决交叉熵损失的类型报错问题回归任务标签y_regression用torch.float32FloatTensor符合回归任务的数值类型要求。MLP 模型结构严格按照图片要求设置2 层隐藏层fc1、fc2神经元数量默认 64/32后续可调参激活函数用 ReLU深度学习常用输出层根据任务类型适配分类不加激活回归直接输出。训练五步法前向传播 → 计算损失 → 清空梯度 → 反向传播 → 更新参数完全匹配 PyTorch 标准训练流程。可视化绘制 Loss 曲线直观监控模型收敛情况对应图片中 “可视化” 要求。运行前置条件安装依赖包pip install torch numpy matplotlib若要使用 GPU需确保电脑有 NVIDIA 显卡安装对应版本的 CUDA 和 cuDNNPyTorch 版本与 CUDA 版本匹配无需手动验证代码会自动检测。总结核心重点分类任务标签用 LongTensor回归用 FloatTensor这是避免训练报错的关键MLP 训练流程环境配置→数据准备→模型定义→损失 / 优化器→训练迭代→可视化是 PyTorch 深度学习的通用模板扩展性当前隐藏层神经元数量固定后续可通过调整hidden_dim1/hidden_dim2实现调参优化。浙大疏锦行

相关文章:

DAY33MLP神经网络的训练

一、 核心知识点回顾 1. 环境配置基础 核心操作:PyTorch 与 CUDA 的安装、验证及环境排查。关键命令: 查看显卡信息:nvidia-smi(CMD 中使用)。CUDA 检查:验证 PyTorch 是否能调用 GPU 加速(.c…...

毕业设计救星:手把手教你用KF-GINS搞定GNSS/INS松组合导航(附代码避坑)

毕业设计实战:从零搭建GNSS/INS松组合导航系统 第一次接触KF-GINS时,我被那些复杂的矩阵运算和坐标系转换搞得晕头转向。作为导航专业的毕业生,我完全理解那种面对开源代码手足无措的感觉——明明知道卡尔曼滤波很重要,但看到满屏…...

欧姆龙CP1H脉冲程序案例及新手入门指南

A1欧姆龙CP1H程序 姆龙标准程序 欧姆龙PLC标准案例模板 本产品适用于新手或者在校生 本程序主要写了欧姆龙CP1H脉冲程序案例, 包含以下: 威纶通触摸屏程序; word详细说明文档 ; 欧姆龙CP1H程序; 里面的文档有详细介绍…...

Turtlebot3+Nav2实战:手把手教你用RVIZ实现室内SLAM建图(避坑指南)

Turtlebot3Nav2实战:从零实现室内SLAM建图的避坑指南 当第一次看到Turtlebot3在未知环境中自主构建地图时,那种科技带来的震撼感至今难忘。作为ROS2生态中最受欢迎的入门级机器人平台,Turtlebot3配合Nav2导航栈能够实现令人惊艳的SLAM建图效果…...

RRT+人工势场法路径规划与APF应用

融合RRT和人工势场法 路径规划 rrt apf 具有开关设置路径规划领域有个经典难题:如何在复杂环境中快速找到安全路径?RRT(快速扩展随机树)和人工势场法这对CP最近被我玩出了新花样。咱们今天不聊理论公式,直接上代码说人…...

别再自己造轮子了!用Three.js的TubeGeometry在Cesium里画空心管道(附完整Vue3代码)

跨引擎三维可视化:用Three.js几何体增强Cesium场景渲染 在三维地理信息系统开发中,Cesium和Three.js都是不可或缺的技术栈。Cesium擅长全球尺度的地理空间可视化,而Three.js则提供了丰富的几何体生成能力。当我们需要在Cesium中实现复杂几何…...

Comsol仿真超表面复现:多级分解通用适用于各种形状,六面体阵列与圆柱体阵列复现相吻合,多物...

comsol仿真超表面复现:多级分解通用,适用各种形状,以下是两篇文献(六面体阵列、圆柱体阵列)的复现都相吻合 多物理场仿真耦合有限元模拟comsol,提供建模思路,包括流体、力学、传热、电磁等 玩C…...

Qwen2-VL-2B-Instruct模型压缩与量化教程:在边缘设备部署视觉语言模型

Qwen2-VL-2B-Instruct模型压缩与量化教程:在边缘设备部署视觉语言模型 想让一个能看懂图片、还能跟你聊天的AI模型,在你的树莓派或者开发板上跑起来吗?听起来有点天方夜谭,毕竟这类视觉语言模型通常都是“大块头”,对…...

OpenClaw - Personal AI Assistant (个人 AI 助理)

OpenClaw - Personal AI Assistant {个人 AI 助理} 1. OpenClaw - Personal AI Assistant2. OpenClaw2.1. Docs2.2. Mattermost 3. ConclusionsReferences OpenClaw (formerly Clawdbot, Moltbot, and Molty) is a free and open-source autonomous artificial intelligence ag…...

带隙基准Bandgap与低压差稳压器Ldo电路

带隙基准Bandgap,低压差稳压器Ldo电路在模拟电路设计中,稳定的电压源是许多系统的基石。带隙基准(Bandgap)和低压差稳压器(LDO)这对黄金搭档,一个负责生成精准电压,另一个负责在恶劣…...

RT-Thread实战:STM32硬件看门狗配置与多任务喂狗策略详解

RT-Thread实战:STM32硬件看门狗配置与多任务喂狗策略详解 在嵌入式系统开发中,系统稳定性是至关重要的考量因素。当系统运行在复杂电磁环境或长时间无人值守的场景时,硬件看门狗(Watchdog)成为保障系统可靠性的最后一道…...

做了一个 AI 鸿蒙 App,我发现逻辑变了

子玥酱 (掘金 / 知乎 / CSDN / 简书 同名) 大家好,我是 子玥酱,一名长期深耕在一线的前端程序媛 👩‍💻。曾就职于多家知名互联网大厂,目前在某国企负责前端软件研发相关工作,主要聚…...

【暖洋葱家庭教育有效果吗】用数据说话:暖洋葱发布年度服务报告,家长满意度高达96.3%

“孩子沉迷手机,说了不听,打又没用,暖洋葱真的能帮我吗?”这是许多家长在咨询时最关心的问题。面对家长的期待,暖洋葱家庭教育坚信:教育不能仅靠口号,效果必须经得起检验。近日,暖洋…...

基于深度学习预测+MPC的车辆轨迹跟踪自动驾驶汽车预测控制Matlab仿真(带参考文献)

✅作者简介:热爱科研的Matlab仿真开发者,擅长毕业设计辅导、数学建模、数据处理、建模仿真、程序设计、完整代码获取、论文复现及科研仿真。 🍎 往期回顾关注个人主页:Matlab科研工作室 👇 关注我领取海量matlab电子…...

现代智能汽车系统——照明系统0

摘要:车辆灯具按功能分为四大类:1)外部照明灯(远近光灯、雾灯等),用于道路照明;2)外部信号灯(转向灯、刹车灯等),用于车辆间通信;3&am…...

UI-TARS-desktop完整指南:vLLM高性能推理+Qwen3-4B-Instruct多模态任务闭环实践

UI-TARS-desktop完整指南:vLLM高性能推理Qwen3-4B-Instruct多模态任务闭环实践 想找一个开箱即用、能看能说、还能帮你操作电脑的AI助手吗?今天要介绍的UI-TARS-desktop,就是一个集成了高性能vLLM推理引擎和强大Qwen3-4B-Instruct多模态模型…...

Laravel7.x十大核心特性解析

Laravel 7.x 版本引入了多项重要特性与优化,以下是核心特性概述: 1. 路由签名语法优化 新增 Route::signed() 和 Route::temporarySigned() 方法,简化签名 URL 的生成与验证: // 生成签名路由 Route::signed(verify, Verificati…...

无速度传感器DTC实战:让电机自己“报“转速

基于MRAS的异步电机直接转矩控制/基于转子磁链模型的MRAS转速辨识/基于反电动势模型的MRAS转速辨识/基于无功功率模型的MRAS转速辨识 在simulink搭建的异步电机模型预测转矩控制模型之上进行改进,把转速环中实际转速从测量值更换为MARS观测器的转速估计值&#xff0…...

保姆级教程:JCG Q30 Pro免拆刷OpenWrt 24.10(附常见问题排查)

JCG Q30 Pro免拆刷OpenWrt 24.10全流程指南与深度优化 为什么选择OpenWrt与JCG Q30 Pro的完美组合 在智能家居和网络设备高度发达的今天,路由器早已不再是简单的网络连接设备。对于技术爱好者而言,一台能够自由定制、性能强劲的路由器,就像…...

AI简历姬支持上传JD后逐段改写简历吗?

摘要 是的,AI简历姬支持上传JD后逐段改写简历。其核心工作流程是:上传或粘贴JD -> 解析JD关键词 -> 将你的现有经历与岗位要求逐项对齐 -> 提供匹配度评分、缺口清单和具体的改写建议。这不同于简单的文案润色,而是围绕“岗位要求 -…...

基于OpenFast联合仿真的独立变桨与统一变桨风电机组控制模型

openfast与simlink联合仿真模型,风电机组独立变桨控制与统一变桨控制。 独立变桨控制。 OpenFast联合仿真。 基于载荷反馈的独立变桨控制 风机变桨控制基于FAST与MATLAB SIMULINK联合仿真模型的非线性风力发电机的PID独立变桨和统一变桨控制下仿真模型。 5MW非线性风…...

MLX90632红外温度传感器Arduino驱动库详解

1. ProtoCentral MLX90632 非接触式红外温度传感器库深度解析1.1 项目定位与工程价值ProtoCentral MLX90632 库是专为 Melexis MLX90632 红外非接触温度传感器设计的 Arduino 兼容驱动库,面向嵌入式系统工程师、硬件开发者及电子爱好者提供开箱即用的高精度测温能力…...

VMware Workstation Pro 17 安装 VyOS 软路由保姆级教程(附镜像下载)

VMware Workstation Pro 17 安装 VyOS 软路由全流程指南 在家庭网络或小型办公环境中部署软路由正逐渐成为技术爱好者和IT从业者的新选择。VyOS作为一款基于Linux的开源网络操作系统,以其轻量级、高性能和丰富的网络功能吸引了大量用户。本文将详细介绍如何在Window…...

python+flask+vue3基于web的社区物业管理平台开题

目录技术选型与架构设计项目模块划分开发环境搭建关键API设计示例前端数据交互数据模型设计开发进度安排测试策略部署方案项目技术支持源码获取详细视频演示 :文章底部获取博主联系方式!同行可合作技术选型与架构设计 后端采用PythonFlask框架&#xff…...

线性代数实战指南:从线性空间基础到高阶应用解析

1. 线性空间:从抽象定义到现实世界 第一次接触线性空间这个概念时,我也被那些抽象的定义搞得头晕眼花。直到有一天在玩3D游戏时突然意识到,游戏里角色的移动、旋转和缩放,本质上都是在操作线性空间中的向量。这才明白线性空间不是…...

【中等】将整数字符串转成整数值-Java

分享一个大牛的人工智能教程。零基础!通俗易懂!风趣幽默!希望你也加入到人工智能的队伍中来!请轻击人工智能教程大家好!欢迎来到我的网站! 人工智能被认为是一种拯救世界、终结世界的技术。毋庸置疑&#x…...

VMware Workstation Pro 17安装openEuler24.03 LTS避坑指南:从镜像下载到网络配置

VMware Workstation Pro 17 安装 openEuler 24.03 LTS 全流程实战与深度优化 作为一款面向数字基础设施的开源操作系统,openEuler 24.03 LTS 凭借其安全稳定、高效易用的特性,正成为企业级应用的新选择。本文将基于VMware Workstation Pro 17虚拟化环境&…...

306. 累加数(dfs回溯)

链接&#xff1a;306. 累加数 题解&#xff1a; class Solution { public:bool isAdditiveNumber(string num) {if (num.size() < 2) {return false;}int begin 0;std::vector<uint64_t> path;return dfs(begin, num, path);}bool dfs(int begin, const std::strin…...

ELF文件格式解析:嵌入式ARM固件的链接、加载与执行机制

1. ELF 文件规范与嵌入式系统二进制格式演进Executable and Linking Format&#xff08;ELF&#xff09;是一种定义明确、高度可扩展的二进制文件格式规范&#xff0c;其核心目标是为不同阶段的软件生命周期——从源码编译、目标文件链接到最终程序加载执行——提供统一、可移植…...

PHP-Resque部署指南:生产环境配置与监控方案

PHP-Resque部署指南&#xff1a;生产环境配置与监控方案 【免费下载链接】php-resque PHP port of resque (Workers and Queueing) 项目地址: https://gitcode.com/gh_mirrors/ph/php-resque PHP-Resque是一个功能强大的PHP任务队列系统&#xff0c;允许开发者将耗时任务…...