当前位置: 首页 > news >正文

【深度学习实验】前馈神经网络(九):整合训练、评估、预测过程(Runner)

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. __init__(初始化)

2. train(训练)

3. evaluate(评估)

4. predict(预测)

5. save_model

6. load_model

7. 代码整合


一、实验介绍

      

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png

0. 导入必要的工具包

import torch
from torch import nn
import torch.nn.functional as F
# 绘画时使用的工具包
import matplotlib.pyplot as plt
# 导入鸢尾花数据集
from sklearn.datasets import load_iris
# 构建自己的数据集,继承自Dataset类
from torch.utils.data import Dataset, DataLoader

1. __init__(初始化)

    def __init__(self, model, optimizer, loss_fn, metric, **kwargs):self.model = modelself.optimizer = optimizerself.loss_fn = loss_fn# 用于计算评价指标self.metric = metric# 记录训练过程中的评价指标变化self.dev_scores = []# 记录训练过程中的损失变化self.train_epoch_losses = []self.dev_losses = []# 记录全局最优评价指标self.best_score = 0
  • 五个参数:
    • model(模型)
    • optimizer(优化器)
    • loss_fn(损失函数)
    • metric(评价指标)
    • 其他可选参数。
  • 该类还定义了一些用于记录训练过程中的指标变化和全局最优指标的属性:
    • self.dev_scores(记录验证集评价指标的变化)
    • self.train_epoch_losses(记录训练集损失的变化)
    • self.dev_losses(记录验证集损失的变化)
    • self.best_score(记录全局最优评价指标)

2. train(训练)

 def train(self, train_loader, dev_loader=None, **kwargs):# 将模型设置为训练模式,此时模型的参数会被更新self.model.train()num_epochs = kwargs.get('num_epochs', 0)log_steps = kwargs.get('log_steps', 100)save_path = kwargs.get('save_path', 'best_mode.pth')eval_steps = kwargs.get('eval_steps', 0)# 运行的step数,不等于epoch数global_step = 0if eval_steps:if dev_loader is None:raise RuntimeError('Error: dev_loader can not be None!')if self.metric is None:raise RuntimeError('Error: Metric can not be None')# 遍历训练的轮数for epoch in range(num_epochs):total_loss = 0# 遍历数据集for step, data in enumerate(train_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long())total_loss += lossif log_steps and global_step % log_steps == 0:print(f'loss:{loss.item():.5f}')loss.backward()self.optimizer.step()self.optimizer.zero_grad()# 每隔一定轮次进行一次验证,由eval_steps参数控制,可以采用不同的验证判断条件if (epoch + 1) % eval_steps == 0:dev_score, dev_loss = self.evaluate(dev_loader, global_step=global_step)print(f'[Evalute] dev score:{dev_score:.5f}, dev loss:{dev_loss:.5f}')if dev_score > self.best_score:self.save_model(f'model_{epoch + 1}.pth')print(f'[Evaluate]best accuracy performance has been updated: {self.best_score:.5f}-->{dev_score:.5f}')self.best_score = dev_score# 验证过程结束后,请记住将模型调回训练模式self.model.train()global_step += 1# 保存当前轮次训练损失的累计值train_loss = (total_loss / len(train_loader)).item()self.train_epoch_losses.append((global_step, train_loss))print('[Train] Train done')

3. evaluate(评估)

    def evaluate(self, dev_loader, **kwargs):assert self.metric is not None# 将模型设置为验证模式,此模式下,模型的参数不会更新self.model.eval()global_step = kwargs.get('global_step', -1)total_loss = 0self.metric.reset()for batch_id, data in enumerate(dev_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long()).item()total_loss += lossself.metric.update(logits, y)dev_loss = (total_loss / len(dev_loader))self.dev_losses.append((global_step, dev_loss))dev_score = self.metric.accumulate()self.dev_scores.append(dev_score)return dev_score, dev_loss

4. predict(预测)

    predict方法用于模型的阶段,输入数据x,返回模型对输入的预测结果。

 def predict(self, x, **kwargs):self.model.eval()logits = self.model(x)return logits

5. save_model

 def save_model(self, save_path):torch.save(self.model.state_dict(),save_path)

6. load_model

  def load_model(self, model_path):self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

7. 代码整合

class Runner(object):def __init__(self, model, optimizer, loss_fn, metric, **kwargs):self.model = modelself.optimizer = optimizerself.loss_fn = loss_fn# 用于计算评价指标self.metric = metric# 记录训练过程中的评价指标变化self.dev_scores = []# 记录训练过程中的损失变化self.train_epoch_losses = []self.dev_losses = []# 记录全局最优评价指标self.best_score = 0# 模型训练阶段def train(self, train_loader, dev_loader=None, **kwargs):# 将模型设置为训练模式,此时模型的参数会被更新self.model.train()num_epochs = kwargs.get('num_epochs', 0)log_steps = kwargs.get('log_steps', 100)save_path = kwargs.get('save_path','best_mode.pth')eval_steps = kwargs.get('eval_steps', 0)# 运行的step数,不等于epoch数global_step = 0if eval_steps:if dev_loader is None:raise RuntimeError('Error: dev_loader can not be None!')if self.metric is None:raise RuntimeError('Error: Metric can not be None')# 遍历训练的轮数for epoch in range(num_epochs):total_loss = 0# 遍历数据集for step, data in enumerate(train_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long())total_loss += lossif log_steps and global_step%log_steps == 0:print(f'loss:{loss.item():.5f}')loss.backward()self.optimizer.step()self.optimizer.zero_grad()# 每隔一定轮次进行一次验证,由eval_steps参数控制,可以采用不同的验证判断条件if (epoch+1)% eval_steps ==  0:dev_score, dev_loss = self.evaluate(dev_loader, global_step=global_step)print(f'[Evalute] dev score:{dev_score:.5f}, dev loss:{dev_loss:.5f}')if dev_score > self.best_score:self.save_model(f'model_{epoch+1}.pth')print(f'[Evaluate]best accuracy performance has been updated: {self.best_score:.5f}-->{dev_score:.5f}')self.best_score = dev_score# 验证过程结束后,请记住将模型调回训练模式   self.model.train()global_step += 1# 保存当前轮次训练损失的累计值train_loss = (total_loss/len(train_loader)).item()self.train_epoch_losses.append((global_step,train_loss))print('[Train] Train done')# 模型评价阶段def evaluate(self, dev_loader, **kwargs):assert self.metric is not None# 将模型设置为验证模式,此模式下,模型的参数不会更新self.model.eval()global_step = kwargs.get('global_step',-1)total_loss = 0self.metric.reset()for batch_id, data in enumerate(dev_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long()).item()total_loss += loss self.metric.update(logits, y)dev_loss = (total_loss/len(dev_loader))self.dev_losses.append((global_step, dev_loss))dev_score = self.metric.accumulate()self.dev_scores.append(dev_score)return dev_score, dev_loss# 模型预测阶段,def predict(self, x, **kwargs):self.model.eval()logits = self.model(x)return logits# 保存模型的参数def save_model(self, save_path):torch.save(self.model.state_dict(),save_path)# 读取模型的参数def load_model(self, model_path):self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

相关文章:

【深度学习实验】前馈神经网络(九):整合训练、评估、预测过程(Runner)

目录 一、实验介绍 二、实验环境 1. 配置虚拟环境 2. 库版本介绍 三、实验内容 0. 导入必要的工具包 1. __init__(初始化) 2. train(训练) 3. evaluate(评估) 4. predict(预测) 5. save_model 6. load_model 7. 代码整合 一、实验介绍 二、实验环境 本系列实验使用…...

002-第一代硬件系统架构确立及产品选型

第一代硬件系统架构确立及产品选型 文章目录 第一代硬件系统架构确立及产品选型项目介绍摘要硬件架构硬件结构选型及设计单片机选型上位机选型扯点别的 关键字: Qt、 Qml、 信号采集机、 数据处理、 上位机 项目介绍 欢迎来到我们的 QML & C 项目&#xff…...

Go基础语法:指针和make和new

8 指针、make、new 8.1 指针(pointer) Go 语言中没有指针操作,只需要记住两个符号即可: & 取内存地址* 根据地址取值 package mainimport "fmt"func main() {a : 18// 获取 a 的地址值并复制给 pp : &a// …...

039_小驰私房菜_Camera perfermance debug

全网最具价值的Android Camera开发学习系列资料~ 作者:8年Android Camera开发,从Camera app一直做到Hal和驱动~ 欢迎订阅,相信能扩展你的知识面,提升个人能力~ 一、抓取trace 1. adb shell "echo vendor.debug.trace.perf=1 >> /system/build.prop" 2. …...

Caché for Windows安装及配置

本文介绍在Windows上安装Cach的操作步骤。本文假设用户熟悉Windows目录结构、实用程序和命令。本文包含如下主要部分:​​​​​​ 1)Cach安装...

代码随想录算法训练营20期|第四十六天|动态规划part08|● 139.单词拆分 ● 关于多重背包,你该了解这些! ● 背包问题总结篇!

139.单词拆分 感觉这个板块要重新刷&#xff0c;完全没有印象 class Solution {public boolean wordBreak(String s, List<String> wordDict) {Set<String> set new HashSet<>(wordDict);boolean[] dp new boolean[s.length() 1];dp[0] true;for (int i…...

系统安装(一)CentOS 7 本地安装

CentOS与Ubuntu并称为Linux最著名的两个发行版&#xff0c;但由于笔者主要从事深度学习图像算法工作&#xff0c;Ubuntu作为谷歌和多数依赖库的亲儿子占据着最高生态位。但最近接手的一个项目里&#xff0c;甲方指定需要在CentOS7上运行项目代码&#xff0c;笔者被迫小小cos了一…...

obsidian使用指南

插入代码块快捷键设置 插入代码块 用英文搜索快捷键名字 英文搜索的【Insert code block】对应的是 (6个点) 中文搜索的【代码块】对应的是 &#xff08;2个点&#xff09; 查看word、excel等非md文件设置 电脑端obsidian->设置->文件与链接->检测所有类型文件->…...

【ardunio】青少年机器人四级实操代码(2023年9月)

目录 一、题目 二、示意图 三、流程图 四、硬件连接 1、舵机 2、超声波 3、LED灯 五、程序 一、题目 实操考题(共1题&#xff0c;共100分) 1. 主题&#xff1a; 迎宾机器人 器件&#xff1a;Atmega328P主控板1块&#xff0c;舵机1个&#xff0c;超声波传感器1个&…...

MYSQL的存储过程

存储过程 存储过程是事先经过编译并存储在数据库中的一段 SQL 语句的集合&#xff0c;调用存储过程可以简化应用开发人员的很多工作&#xff0c;减少数据在数据库和应用服务器之间的传输&#xff0c;对于提高数据处理的效率是有好处的。存储过程思想上很简单&#xff0c;就是…...

[kubernetes/docker] failed to resolve reference ...:latest: not found

问题描述: pod一直pending, kubectl describe pod ... 显示: Warning Failed 9s (x3 over 63s) kubelet Failed to pull image "mathemagics/my-kube-scheduler": rpc error: code NotFound desc failed to pull and unpack image "docker…...

彻底解决win11系统0x80070032

经过各种尝试&#xff0c;终于找到原因。第一个是电脑加密软件&#xff0c;第二个是需要的部分功能没有开启&#xff0c;第三个BIOS设置。个人觉得第三个不重要。 解决方法 笔记本型号 笔记本型号是Thinkpad T14 gen2。进入BIOS的按键是按住Enter键。 1、关闭山丽防水墙服务…...

解决因为修改SELINUX配置文件出错导致Faild to load SELinux poilcy无法进入CentOS7系统的问题

一、问题 最近学习Kubernetes&#xff0c;需要设置永久关闭SELINUX,结果修改错了一个SELINUX配置参数&#xff0c;关机重新启动后导致无法进入CentOS7系统&#xff0c;卡在启动进度条界面。 二、解决 多次重启后&#xff0c;在启动日志中发现 Faild to load SELinux poilcy…...

flask中的跨域处理-方法二不使用第三方库

方法1(第三方库) pip install flask-cors from flask import Flask from flask_cors import CORSapp = Flask(__name__) CORS(app, resources={r"/api/*": {"origins": ["http://localhost:63342", "http://localhost:63345"]}})方…...

矿山定位系统-矿井人员定位系统在矿山自动化安全监控过程中的应用

一&#xff0c;矿井人员定位系统现阶段使用的必要性 1&#xff0c;煤矿开采是一项非常特殊的工作&#xff0c;现场属于非常复杂多变的环境&#xff0c;井下信号极差&#xff0c;数据传输非常不稳定&#xff0c;人员安全难以保证&#xff0c;煤矿企业一直在研究如何使用更合适的…...

JS-ECharts-前端图表 多层级联合饼图、柱状堆叠图、柱/线组合图、趋势图、自定义中线、平均线、气泡备注点

本篇博客背景为JavaScript。在ECharts在线编码快速上手&#xff0c;绘制相关前端可视化图表。 ECharts官网&#xff1a;https://echarts.apache.org/zh/index.html 其他的一些推荐&#xff1a; AntV&#xff1a;https://antv.vision/zh chartcube&#xff1a;https://chartcub…...

【eslint】屏蔽语言提醒

在 JavaScript 中&#xff0c;ESLint 是一种常用的静态代码分析工具&#xff0c;它用于检测和提醒代码中的潜在问题和风格问题。有时候&#xff0c;在某些特定情况下&#xff0c;你可能希望临时屏蔽或禁用某些 ESLint 的提醒信息&#xff0c;以便消除不必要的警告或避免不符合项…...

【python】入门第一课:了解基本语法(数据类型)

目录 一、介绍 1、什么是python&#xff1f; 2、python的几个特点 二、实例 1、注释 2、数据类型 2.1、字符串 str 2.2、整数 int 2.3、浮点数 float 2.4、布尔 bool 2.5、列表 list 2.6、元组 tuple 2.7、集合 set 2.8、字典 dict 一、介绍 1、什么是python&…...

csa从初阶到大牛(练习题2-查询)

新建2个文件d1.txt d2.txt ,使用vim打开d1.txt 输入“Hello World”字符串,将b1.txt 硬链接到b2.txt &#xff0c;查看2个文件的硬连接数 # 新建文件d1.txt和d2.txt touch d1.txt d2.txt# 使用vim编辑d1.txt并输入文本"Hello World" vim d1.txt# 创建硬链接b2.…...

【视觉SLAM入门】8. 回环检测,词袋模型,字典,感知,召回,机器学习

"见人细过 掩匿盖覆” 1. 意义2. 做法2.1 词袋模型和字典2.1.2 感知偏差和感知变异2.1.2 词袋2.1.3 字典 2.2 匹配(相似度)计算 3. 提升 前言&#xff1a; 前端提取数据&#xff0c;后端优化数据&#xff0c;但误差会累计&#xff0c;需要回环检测构建全局一致的地图&…...

装饰模式(Decorator Pattern)重构java邮件发奖系统实战

前言 现在我们有个如下的需求&#xff0c;设计一个邮件发奖的小系统&#xff0c; 需求 1.数据验证 → 2. 敏感信息加密 → 3. 日志记录 → 4. 实际发送邮件 装饰器模式&#xff08;Decorator Pattern&#xff09;允许向一个现有的对象添加新的功能&#xff0c;同时又不改变其…...

React第五十七节 Router中RouterProvider使用详解及注意事项

前言 在 React Router v6.4 中&#xff0c;RouterProvider 是一个核心组件&#xff0c;用于提供基于数据路由&#xff08;data routers&#xff09;的新型路由方案。 它替代了传统的 <BrowserRouter>&#xff0c;支持更强大的数据加载和操作功能&#xff08;如 loader 和…...

华为OD机试-食堂供餐-二分法

import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...

spring:实例工厂方法获取bean

spring处理使用静态工厂方法获取bean实例&#xff0c;也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下&#xff1a; 定义实例工厂类&#xff08;Java代码&#xff09;&#xff0c;定义实例工厂&#xff08;xml&#xff09;&#xff0c;定义调用实例工厂&#xff…...

【Go】3、Go语言进阶与依赖管理

前言 本系列文章参考自稀土掘金上的 【字节内部课】公开课&#xff0c;做自我学习总结整理。 Go语言并发编程 Go语言原生支持并发编程&#xff0c;它的核心机制是 Goroutine 协程、Channel 通道&#xff0c;并基于CSP&#xff08;Communicating Sequential Processes&#xff0…...

项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)

Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败&#xff0c;具体原因是客户端发送了密码认证请求&#xff0c;但Redis服务器未设置密码 1.为Redis设置密码&#xff08;匹配客户端配置&#xff09; 步骤&#xff1a; 1&#xff09;.修…...

uniapp手机号一键登录保姆级教程(包含前端和后端)

目录 前置条件创建uniapp项目并关联uniClound云空间开启一键登录模块并开通一键登录服务编写云函数并上传部署获取手机号流程(第一种) 前端直接调用云函数获取手机号&#xff08;第三种&#xff09;后台调用云函数获取手机号 错误码常见问题 前置条件 手机安装有sim卡手机开启…...

Sklearn 机器学习 缺失值处理 获取填充失值的统计值

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 使用 Scikit-learn 处理缺失值并提取填充统计信息的完整指南 在机器学习项目中,数据清…...

jdbc查询mysql数据库时,出现id顺序错误的情况

我在repository中的查询语句如下所示&#xff0c;即传入一个List<intager>的数据&#xff0c;返回这些id的问题列表。但是由于数据库查询时ID列表的顺序与预期不一致&#xff0c;会导致返回的id是从小到大排列的&#xff0c;但我不希望这样。 Query("SELECT NEW com…...

如何在Windows本机安装Python并确保与Python.NET兼容

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…...