当前位置: 首页 > news >正文

现代卷积网络实战系列2:训练函数、PyTorch构建LeNet网络

4、训练函数

4.1 调用训练函数

train(epochs, net, train_loader, device, optimizer, test_loader, true_value)

因为每一个epoch训练结束后,我们需要测试一下这个网络的性能,所有会在训练函数中频繁调用测试函数,所有测试函数中所有需要的参数,训练函数都需要
这七个参数,是训练一个神经网络所需要的最少参数

4.2 训练函数

训练函数中,所有训练集进行多次迭代,而每次迭代又会将数据分成多个批次进行迭代

def train(epochs, net, train_loader, device, optimizer, test_loader, true_value):for epoch in range(1, epochs + 1):net.train()all_train_loss = []for batch_idx, (data, target) in enumerate(train_loader):data = data.to(device)target = target.to(device)optimizer.zero_grad()output = net(data)loss = F.cross_entropy(output, target)loss.backward()optimizer.step()cur_train_loss = loss.item()all_train_loss.append(cur_train_loss)train_loss = np.round(np.mean(all_train_loss) * 1000, 2)print('\nepoch step:', epoch)print('training loss: ', train_loss)test(net, test_loader, device, true_value, epoch)print("\nTraining finished")
  1. 定义训练函数
  2. 安装epochs迭代数据
  3. 进入pytorch的训练模式
  4. all_train_loss 存放训练集5万张图片的损失值
  5. 按照batch取数据
  6. 数据进入GPU
  7. 标签进入GPU
  8. 梯度清零
  9. 当前batch进入网络后得到输出
  10. 根据输出得到当前损失
  11. 反向传播
  12. 梯度下降
  13. 获取损失的损失值(PyTorch框架中的数据)
  14. 把当前batch的损失加入all_train_loss数组中,结束batch的迭代
  15. 将5张图片的损失计算出来并且进行求平均,这里乘以1000是因为我觉得计算出的损失太小了,所以乘以1000,方便看损失的变化,保留两位有效数字
  16. 打印当前epoch
  17. 打印损失
  18. 调用测试函数,测试当前训练的网络的性能,结束epoch的迭代
  19. 打印训练完成

5、LeNet

5.1 网络结构

LeNet可以说是首次提出卷积神经网络的模型
主要包含下面的网络层:

  1. 5*5的二维卷积
  2. sigmoid激活函数(这里使用了relu)
  3. 5*5的二维卷积
  4. sigmoid激活函数
  5. 数据一维化
  6. 全连接层
  7. 全连接层
  8. softmax分类器

将网络结构打印出来:

LeNet(
-------(conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
-------(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
-------(conv2_drop): Dropout2d(p=0.5, inplace=False)
-------(fc1): Linear(in_features=320, out_features=50, bias=True)
-------(fc2): Linear(in_features=50, out_features=10, bias=True)
)

5.2 PyTorch构建LeNet

class LeNet(nn.Module):def __init__(self, num_classes):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, num_classes)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)

这个时候已经是一个完整的项目了,看看10个epoch训练过程的打印:

D:\conda\envs\pytorch\python.exe A:\0_MNIST\train.py

Reading data…
train_data: (60000, 28, 28) train_label (60000,)
test_data: (10000, 28, 28) test_label (10000,)

Initialize neural network
test loss: 2301.68
test accuracy: 11.3 %

epoch step: 1
training loss: 634.74
test loss: 158.03
test accuracy: 95.29 %

epoch step: 2
training loss: 324.04
test loss: 107.62
test accuracy: 96.55 %

epoch step: 3
training loss: 271.25
test loss: 88.43
test accuracy: 97.04 %

epoch step: 4
training loss: 236.69
test loss: 70.94
test accuracy: 97.61 %

epoch step: 5
training loss: 211.05
test loss: 69.69
test accuracy: 97.72 %

epoch step: 6
training loss: 199.28
test loss: 62.04
test accuracy: 97.98 %

epoch step: 7
training loss: 187.11
test loss: 59.65
test accuracy: 97.98 %

epoch step: 8
training loss: 178.79
test loss: 53.89
test accuracy: 98.2 %

epoch step: 9
training loss: 168.75
test loss: 51.83
test accuracy: 98.43 %

epoch step: 10
training loss: 160.83
test loss: 50.35
test accuracy: 98.4 %

Training finished
进程已结束,退出代码为 0

可以看出基本上只要一个epoch就可以得到很好的训练效果了,后续的epoch中的提升比较小

相关文章:

现代卷积网络实战系列2:训练函数、PyTorch构建LeNet网络

4、训练函数 4.1 调用训练函数 train(epochs, net, train_loader, device, optimizer, test_loader, true_value)因为每一个epoch训练结束后,我们需要测试一下这个网络的性能,所有会在训练函数中频繁调用测试函数,所有测试函数中所有需要的…...

rust特性

特性,也叫特质,英文是trait。 trait是一种特殊的类型,用于抽象某些方法。trait类似于其他编程语言中的接口,但又有所不同。 trait定义了一组方法,其他类型可以各自实现这个trait的方法,从而形成多态。 一、…...

TouchGFX之画布控件

TouchGFX的画布控件,在使用相对较小的存储空间的同时保持高性能,可提供平滑、抗锯齿效果良好的几何图形绘制。 TouchGFX 设计器中可用的画布控件: LineCircleShapeLine Progress圆形进度条 存储空间分配和使用​ 为了生成反锯齿效果良好的…...

STM32F103RCT6学习笔记2:串口通信

今日开始快速掌握这款STM32F103RCT6芯片的环境与编程开发,有关基础知识的部分不会多唠,直接实践与运用!文章贴出代码测试工程与测试效果图: 目录 串口通信实验计划: 串口通信配置代码: 测试效果图&#…...

Opencv-图像噪声(均值滤波、高斯滤波、中值滤波)

图像的噪声 图像的平滑 均值滤波 均值滤波代码实现 import cv2 as cv import numpy as np import matplotlib.pyplot as plt from pylab import mplmpl.rcParams[font.sans-serif] [SimHei]img cv.imread("dog.png")#均值滤波cv.blur(img, (5, 5))将对图像img进行…...

MasterAlign相机参数设置-增益调节

相机参数设置-曝光时间调节操作说明 相机参数的设置对于获取清晰、准确的图像至关重要。曝光时间是其中一个关键参数,它直接影响图像的亮度和清晰度。以下是关于曝光时间调节的详细操作步骤,以帮助您轻松进行设置。 步骤一:登录系统 首先&…...

9月22日,每日信息差

今天是2023年09月22日,以下是为您准备的14条信息差 第一、亚马逊将于2024年初在Prime Video中加入广告。Prime Video内容中的广告将于2024年初在美国、英国、德国和加拿大推出,随后晚些时候在法国、意大利、西班牙、墨西哥和澳大利亚推出 第二、中国移…...

Java版本企业工程项目管理系统源码+spring cloud 系统管理+java 系统设置+二次开发

工程项目各模块及其功能点清单 一、系统管理 1、数据字典:实现对数据字典标签的增删改查操作 2、编码管理:实现对系统编码的增删改查操作 3、用户管理:管理和查看用户角色 4、菜单管理:实现对系统菜单的增删改查操…...

Android studio中如何下载sdk

打开 file -> settings 这个页面, 在要下载的 SDK 前面勾上, 然后点 apply 在 platforms 中就可以看到下载好的 SDK: Android SDK目录结构详细介绍可以参考这篇文章: 51CTO博客- Android SDK目录结构...

STM32单片机中国象棋TFT触摸屏小游戏

实践制作DIY- GC0167-中国象棋 一、功能说明: 基于STM32单片机设计-中国象棋 二、功能介绍: 硬件组成:STM32F103RCT6最小系统2.8寸TFT电阻触摸屏24C02存储器1个按键(悔棋) 游戏规则: 1.有悔棋键&…...

【PHP图片托管】CFimagehost搭建私人图床 - 无需数据库支持

文章目录 1.前言2. CFImagehost网站搭建2.1 CFImagehost下载和安装2.2 CFImagehost网页测试2.3 cpolar的安装和注册 3.本地网页发布3.1 Cpolar临时数据隧道3.2 Cpolar稳定隧道(云端设置)3.3.Cpolar稳定隧道(本地设置) 4.公网访问测…...

CCITT 标准的CRC-16检验算法

/******该文件使用查表法计算CCITT 标准的CRC-16检验码,并附测试代码********/ #include #define CRC_INIT 0xffff //CCITT初始CRC为全1 #define GOOD_CRC 0xf0b8 //校验时计算出的固定结果值 /****下表是常用ccitt 16,生成式1021反转成8408后的查询表格****/ u…...

docker启动mysql服务

创建基础文件 mkdir mysql mkdir -p mysql/data获取默认的my.cnf docker run -name mysql -d -p 3306:3306 mysql:latest docker cp mysql:/etc/my.cnf ./vim mysql/my.cnf # For advice on how to change settings please see # http://dev.mysql.com/doc/refman/8.1/en/se…...

Postman应用——Request数据导入导出

文章目录 导入请求数据导出请求数据导出Collection导出Environments 导出所有请求数据导出请求响应数据 Postman可以导入导出Request和Variable变量配置,可以通过文本方式(JOSN文本)或链接方式进行导入导出。 导入请求数据 可以通过JSON文件…...

十四、MySql的用户管理

文章目录 一、用户管理二、用户(一)用户信息(二)创建用户1.语法:2.案例: (三) 删除用户1.语法:2.示例: (四)修改用户密码1.语法&#…...

01.自动化交易综述

算法交易的概念: 利用自动化平台,执行预先设置的一系列规则完成交易行为。 算法交易的优势 1.历史数据评估 2.执行高效 3.无主观情绪输入 4.可度量评价 5.交易频率 算法交易的劣势 1.成本,成本低难以体现收益 2.技巧 算法交易流程 大前…...

基于SpringBoot的网上超市系统的设计与实现

目录 前言 一、技术栈 二、系统功能介绍 管理员功能实现 用户功能实现 三、核心代码 1、登录模块 2、文件上传模块 3、代码封装 前言 网络技术和计算机技术发展至今,已经拥有了深厚的理论基础,并在现实中进行了充分运用,尤其是基于计…...

国内首家!阿里云 Elasticsearch 8.9 版本释放 AI 搜索新动能

简介: 阿里云作为国内首家上线 Elasticsearch 8.9版本的厂商,在提供 Elasticsearch Relevance Engine™ (ESRE™) 引擎的基础上,提供增强 AI 的最佳实践与 ES 本身的混合搜索能力,为用户带来了更多创新和探索的可能性。 近年来&a…...

uniapp获取一周日期和星期

UniApp可以使用JavaScript中的Date对象来获取当前日期和星期几。以下是一个示例代码,可以获取当前日期和星期几,并输出在一周内的每天早上和晚上: // 获取当前日期和星期 let date new Date(); let weekdays ["Sunday", "M…...

QT之QListWidget的介绍

QListWidget常用成员函数 1、成员函数介绍2、例子显示图片和按钮的例子 1、成员函数介绍 1)QListWidget(QWidget *parent nullptr) 构造函数,创建一个新的QListWidget对象。 2)void addItem(const QString &label) 在列表末尾添加一个项目,项目标…...

接口测试中缓存处理策略

在接口测试中,缓存处理策略是一个关键环节,直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性,避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明: 一、缓存处理的核…...

java_网络服务相关_gateway_nacos_feign区别联系

1. spring-cloud-starter-gateway 作用:作为微服务架构的网关,统一入口,处理所有外部请求。 核心能力: 路由转发(基于路径、服务名等)过滤器(鉴权、限流、日志、Header 处理)支持负…...

2025年能源电力系统与流体力学国际会议 (EPSFD 2025)

2025年能源电力系统与流体力学国际会议(EPSFD 2025)将于本年度在美丽的杭州盛大召开。作为全球能源、电力系统以及流体力学领域的顶级盛会,EPSFD 2025旨在为来自世界各地的科学家、工程师和研究人员提供一个展示最新研究成果、分享实践经验及…...

在 Nginx Stream 层“改写”MQTT ngx_stream_mqtt_filter_module

1、为什么要修改 CONNECT 报文? 多租户隔离:自动为接入设备追加租户前缀,后端按 ClientID 拆分队列。零代码鉴权:将入站用户名替换为 OAuth Access-Token,后端 Broker 统一校验。灰度发布:根据 IP/地理位写…...

项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)

Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败,具体原因是客户端发送了密码认证请求,但Redis服务器未设置密码 1.为Redis设置密码(匹配客户端配置) 步骤: 1).修…...

laravel8+vue3.0+element-plus搭建方法

创建 laravel8 项目 composer create-project --prefer-dist laravel/laravel laravel8 8.* 安装 laravel/ui composer require laravel/ui 修改 package.json 文件 "devDependencies": {"vue/compiler-sfc": "^3.0.7","axios": …...

深入浅出深度学习基础:从感知机到全连接神经网络的核心原理与应用

文章目录 前言一、感知机 (Perceptron)1.1 基础介绍1.1.1 感知机是什么?1.1.2 感知机的工作原理 1.2 感知机的简单应用:基本逻辑门1.2.1 逻辑与 (Logic AND)1.2.2 逻辑或 (Logic OR)1.2.3 逻辑与非 (Logic NAND) 1.3 感知机的实现1.3.1 简单实现 (基于阈…...

STM32HAL库USART源代码解析及应用

STM32HAL库USART源代码解析 前言STM32CubeIDE配置串口USART和UART的选择使用模式参数设置GPIO配置DMA配置中断配置硬件流控制使能生成代码解析和使用方法串口初始化__UART_HandleTypeDef结构体浅析HAL库代码实际使用方法使用轮询方式发送使用轮询方式接收使用中断方式发送使用中…...

若依登录用户名和密码加密

/*** 获取公钥:前端用来密码加密* return*/GetMapping("/getPublicKey")public RSAUtil.RSAKeyPair getPublicKey() {return RSAUtil.rsaKeyPair();}新建RSAUti.Java package com.ruoyi.common.utils;import org.apache.commons.codec.binary.Base64; im…...

__VUE_PROD_HYDRATION_MISMATCH_DETAILS__ is not explicitly defined.

这个警告表明您在使用Vue的esm-bundler构建版本时,未明确定义编译时特性标志。以下是详细解释和解决方案: ‌问题原因‌: 该标志是Vue 3.4引入的编译时特性标志,用于控制生产环境下SSR水合不匹配错误的详细报告1使用esm-bundler…...