LeNet-5(fashion-mnist)
文章目录
- 前言
- LeNet
- 模型训练
前言
LeNet是最早发布的卷积神经网络之一。该模型被提出用于识别图像中的手写数字。
LeNet
LeNet-5由以下两个部分组成
- 卷积编码器(2)
- 全连接层(3)
卷积块由一个卷积层、一个sigmoid激活函数和一个平均汇聚层组成。
第一个卷积层有6个输出通道,第二个卷积层有16个输出通道。采用2×2的汇聚操作,且步幅为2.
3个全连接层分别有120,84,10个输出。
此处对原始模型做出部分修改,去除最后一层的高斯激活。
net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(16*5*5,120),nn.Sigmoid(),nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10))
模型训练
为了加快训练,使用GPU计算测试集上的精度以及训练过程中的计算。
此处采用xavier初始化模型参数以及交叉熵损失函数和小批量梯度下降。
batch_size=256
train_iter,test_iter=data_iter.load_data_fashion_mnist(batch_size)
将数据送入GPU进行计算测试集准确率
def evaluate_accuracy_gpu(net,data_iter,device=None):"""使用GPU计算模型在数据集上的精度"""if isinstance(net,torch.nn.Module):net.eval()if not device:device=next(iter(net.parameters())).device# 正确预测的数量,预测的总数eva = 0.0y_num = 0.0with torch.no_grad():for X,y in data_iter:if isinstance(X,list):X=[x.to(device) for x in X]else:X=X.to(device)y=y.to(device)eva += accuracy(net(X), y)y_num += y.numel()return eva/y_num
训练过程同样将数据送入GPU计算
def train_epoch_gpu(net, train_iter, loss, updater,device):# 训练损失之和,训练准确数之和,样本数train_loss_sum = 0.0train_acc_sum = 0.0num_samples = 0.0# timer = d2l.torch.Timer()for i, (X, y) in enumerate(train_iter):# timer.start()updater.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()updater.step()with torch.no_grad():train_loss_sum += l * X.shape[0]train_acc_sum += evaluation.accuracy(y_hat, y)num_samples += X.shape[0]# timer.stop()return train_loss_sum/num_samples,train_acc_sum/num_samplesdef train_gpu(net,train_iter,test_iter,num_epochs,lr,device):def init_weights(m):if type(m)==torch.nn.Linear or type(m)==torch.nn.Conv2d:torch.nn.init.xavier_uniform_(m.weight)net.apply(init_weights)net.to(device)print('training on',device)optimizer=torch.optim.SGD(net.parameters(),lr=lr)loss=torch.nn.CrossEntropyLoss()# num_batches=len(train_iter)tr_l=[]tr_a=[]te_a=[]for epoch in range(num_epochs):net.train()train_metric=train_epoch_gpu(net,train_iter,loss,optimizer,device)test_accuracy = evaluation.evaluate_accuracy_gpu(net, test_iter)train_loss, train_acc = train_metrictrain_loss = train_loss.cpu().detach().numpy()tr_l.append(train_loss)tr_a.append(train_acc)te_a.append(test_accuracy)print(f'epoch: {epoch + 1}, train_loss: {train_loss}, train_acc: {train_acc}, test_acc:{test_accuracy}')x = torch.arange(num_epochs)plt.plot((x + 1), tr_l, '-', label='train_loss')plt.plot(x + 1, tr_a, '--', label='train_acc')plt.plot(x + 1, te_a, '-.', label='test_acc')plt.legend()plt.show()print(f'on {str(device)}')
lr,num_epochs=0.9,10
Train.train_gpu(net,train_iter,test_iter,num_epochs,lr,device='cuda')


相关文章:
LeNet-5(fashion-mnist)
文章目录 前言LeNet模型训练 前言 LeNet是最早发布的卷积神经网络之一。该模型被提出用于识别图像中的手写数字。 LeNet LeNet-5由以下两个部分组成 卷积编码器(2)全连接层(3) 卷积块由一个卷积层、一个sigmoid激活函数和一个…...
Unity中URP下开启和使用深度图
文章目录 前言一、在Unity中打开URP下的深度图二、在Shader中开启深度图1、使用不透明渲染队列才可以使用深度图2、半透明渲染队列深度图就会关闭 三、URP深度图 和 BRP深度图的区别四、在Shader中,使用深度图1、定义纹理和采样器2、在片元着色器对深度图采样并且输…...
类似东郊到家上门预约系统需要具备哪些功能,预约系统应该怎么做
随着上门服务需求的持续增长,各类APP小程序应运而生。吸引了无数商家投资者,纷纷想要开发一款类似于"东郊到家"这样的上门服务软件。要想成功,这样的软件需具备以下核心功能: 1. 快速注册与登录:用户能通过手…...
鸿蒙APP和Android的区别
鸿蒙(HarmonyOS)和Android是两个不同的操作系统,它们有一些区别,包括架构、开发者支持、应用生态和一些设计理念。以下是鸿蒙APP和Android APP之间的一些主要区别,希望对大家有所帮助。北京木奇移动技术有限公司&#…...
给Flutter + FireBase 增加 badge 徽章,App启动器 通知红点。
在此之前需要配置好 firebase 在flutter 在项目中。(已经配置好的可以忽略此提示) Firebase 配置教程:flutter firebase 云消息通知教程 (android-安卓、ios-苹果)_flutter firebase_messaging ios环境配置-CSDN博客 由于firebase 提供的消息…...
2024年中国杭州|网络安全技能大赛(CTF)正式开启竞赛报名
前言 一、CTF简介 CTF(Capture The Flag)中文一般译作夺旗赛,在网络安全领域中指的是网络安全技术人员之间进行技术竞技的一种比赛形式。CTF起源于1996年DEFCON全球黑客大会,以代替之前黑客们通过互相发起真实攻击进行技术比拼的…...
112.Qt中的窗口类
我们在通过Qt向导窗口基于窗口的应用程序的项目过程中倒数第二步让我们选择跟随项目创建的第一个窗口的基类, 下拉菜单中有三个选项, 分别为: QMainWindow、QDialog、QWidget如下图: 常用的窗口类有3个 在创建Qt窗口的时候, 需要让自己的窗口类继承上述三个窗口类的…...
如何设置电脑桌面提醒,电脑笔记软件哪个好?
对于大多数上班族来说,每天要完成的待办事项实在太多了,如果不能及时去处理,很容易因为各种因素导致忘记,从而给自己带来不少麻烦。所以,我们往往会借助一些提醒类的软件将各项任务逐一记录下来,然后设置上…...
C# HttpClient Get Post简单封装
文章目录 前言封装好的代码测试接口测试代码 前言 微软官方有Get和Post请求,我把他简单化处理一下 封装好的代码 public class MyHttpHelper{private string baseUrl;/// <summary>/// 基础Api/// </summary>public string BaseUrl{get{return baseUr…...
创建网格(Grid/GridItem)
目录 1、概述 2、布局与约束 3、设置排列方式 3.1设置行列数量与占比 3.2、设置子组件所占行列数 3.3、设置主轴方向 3.4、在网格布局中显示数据 3.5、设置行列间距 4、构建可滚动的网格布局 5、实现简单的日历功能 6、性能优化 1、概述 网格布局是由“行”和“列”分…...
思科路由器忘记密码怎么重置
断电重启路由器,在开机过程中按下CtrlPause/break,或者只按下Pause/break(没有测试),在PT(Cisco Packet Tracert)中则需要按CtrlC。路由器会进入rommon >模式。 切换到0x2142模式࿰…...
JVM基础(2)——JVM内存模型
一、简介 JVM会加载类到内存中,所以 JVM 中必然会有一块内存区域来存放我们写的那些类。Java中有类对象、普通对象、本地变量、方法信息等等各种对象信息,所以JVM会对内存区域进行划分: JDK1.8及以后,上图中的方法区变成了Metasp…...
使用 Process Explorer 和 Windbg 排查软件线程堵塞问题
目录 1、问题说明 2、线程堵塞的可能原因分析 3、使用Windbg和Process Explorer确定线程中发生了死循环 4、根据Windbg中显示的函数调用堆栈去查看源码,找到问题 4.1、在Windbg定位发生死循环的函数的方法 4.2、在Windbg中查看变量的值去辅助分析 4.3、是循环…...
做科技类的展台3d模型用什么材质比较好---模大狮模型网
对于科技类展台3D模型,以下是几种常用的材质选择: 金属材质:金属材质常用于科技展台的现代感设计,如不锈钢、铝合金或镀铬材质。金属材质可以赋予展台一个科技感和高档感,同时还可以反射光线,增加模型的真实…...
EasyExcel简单实例(未完待续)
EasyExcel简单实例 准备工作场景一:读取 Student 表需求1:简单读取需求2:读取到异常信息时不中断需求3:读取所有的sheet工作表需求4:读取指定的sheet工作表需求5:从指定的行开始读取 场景二:写入…...
ROS2学习笔记一:安装及测试
目录 前言 1 ROS2安装与卸载 1.1 安装虚拟机 1.2 ROS2 humble安装 2 ROS2测试 2.1 topic测试 2.2 小海龟测试 2.3 RQT可视化 2.4 占用空间 前言 ROS2的前身是ROS,ROS即机器人操作系统(Robot Operating System),ROS为了“提高机器人…...
Xcode14.3.1真机调试iOS17的方法
Hello,大家好我是咕噜铁蛋!Xcode 是苹果官方开发工具,它提供了完整的开发环境和工具集,支持开发 iOS、macOS、watchOS 和 tvOS 应用程序。对于 iOS 开发者来说,Xcode 是必备的工具之一。而随着 iOS 系统的不断更新和升…...
主流大语言模型从预训练到微调的技术原理
引言 本文设计的内容主要包含以下几个方面: 比较 LLaMA、ChatGLM、Falcon 等大语言模型的细节:tokenizer、位置编码、Layer Normalization、激活函数等。大语言模型的分布式训练技术:数据并行、张量模型并行、流水线并行、3D 并行、零冗余优…...
Linux中vim查看文件某内容
一、编辑文件命令 [rootyinheqilin ~]# vim test.txt 1,在编辑的文件中连续按2次键盘的【g】键,光标会移动到文档开头第一行 2,输入一个大写 G,光标会跳转到文件的最后一行第一列(末行) 二、查看文件内容命令 gre…...
阿里云提示服务器ip暴露该怎么办?-速盾网络(sudun)
当阿里云提示服务器IP暴露的时候,这意味着您的服务器可能面临安全风险,因为黑客可以通过知道服务器的IP地址来尝试入侵您的系统。在这种情况下,您应该立即采取措施来保护您的服务器和数据。以下是一些建议: 更改服务器IP地址&…...
大数据学习栈记——Neo4j的安装与使用
本文介绍图数据库Neofj的安装与使用,操作系统:Ubuntu24.04,Neofj版本:2025.04.0。 Apt安装 Neofj可以进行官网安装:Neo4j Deployment Center - Graph Database & Analytics 我这里安装是添加软件源的方法 最新版…...
python打卡day49
知识点回顾: 通道注意力模块复习空间注意力模块CBAM的定义 作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程 import torch import torch.nn as nn# 定义通道注意力 class ChannelAttention(nn.Module):def __init__(self,…...
在HarmonyOS ArkTS ArkUI-X 5.0及以上版本中,手势开发全攻略:
在 HarmonyOS 应用开发中,手势交互是连接用户与设备的核心纽带。ArkTS 框架提供了丰富的手势处理能力,既支持点击、长按、拖拽等基础单一手势的精细控制,也能通过多种绑定策略解决父子组件的手势竞争问题。本文将结合官方开发文档,…...
23-Oracle 23 ai 区块链表(Blockchain Table)
小伙伴有没有在金融强合规的领域中遇见,必须要保持数据不可变,管理员都无法修改和留痕的要求。比如医疗的电子病历中,影像检查检验结果不可篡改行的,药品追溯过程中数据只可插入无法删除的特性需求;登录日志、修改日志…...
Linux简单的操作
ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...
Android15默认授权浮窗权限
我们经常有那种需求,客户需要定制的apk集成在ROM中,并且默认授予其【显示在其他应用的上层】权限,也就是我们常说的浮窗权限,那么我们就可以通过以下方法在wms、ams等系统服务的systemReady()方法中调用即可实现预置应用默认授权浮…...
在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用
1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...
自然语言处理——循环神经网络
自然语言处理——循环神经网络 循环神经网络应用到基于机器学习的自然语言处理任务序列到类别同步的序列到序列模式异步的序列到序列模式 参数学习和长程依赖问题基于门控的循环神经网络门控循环单元(GRU)长短期记忆神经网络(LSTM)…...
MySQL账号权限管理指南:安全创建账户与精细授权技巧
在MySQL数据库管理中,合理创建用户账号并分配精确权限是保障数据安全的核心环节。直接使用root账号进行所有操作不仅危险且难以审计操作行为。今天我们来全面解析MySQL账号创建与权限分配的专业方法。 一、为何需要创建独立账号? 最小权限原则…...
深度学习水论文:mamba+图像增强
🧀当前视觉领域对高效长序列建模需求激增,对Mamba图像增强这方向的研究自然也逐渐火热。原因在于其高效长程建模,以及动态计算优势,在图像质量提升和细节恢复方面有难以替代的作用。 🧀因此短时间内,就有不…...
