LeNet-5(fashion-mnist)
文章目录
- 前言
- LeNet
- 模型训练
前言
LeNet是最早发布的卷积神经网络之一。该模型被提出用于识别图像中的手写数字。
LeNet
LeNet-5由以下两个部分组成
- 卷积编码器(2)
- 全连接层(3)
卷积块由一个卷积层、一个sigmoid激活函数和一个平均汇聚层组成。
第一个卷积层有6个输出通道,第二个卷积层有16个输出通道。采用2×2的汇聚操作,且步幅为2.
3个全连接层分别有120,84,10个输出。
此处对原始模型做出部分修改,去除最后一层的高斯激活。
net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(16*5*5,120),nn.Sigmoid(),nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10))
模型训练
为了加快训练,使用GPU计算测试集上的精度以及训练过程中的计算。
此处采用xavier初始化模型参数以及交叉熵损失函数和小批量梯度下降。
batch_size=256
train_iter,test_iter=data_iter.load_data_fashion_mnist(batch_size)
将数据送入GPU进行计算测试集准确率
def evaluate_accuracy_gpu(net,data_iter,device=None):"""使用GPU计算模型在数据集上的精度"""if isinstance(net,torch.nn.Module):net.eval()if not device:device=next(iter(net.parameters())).device# 正确预测的数量,预测的总数eva = 0.0y_num = 0.0with torch.no_grad():for X,y in data_iter:if isinstance(X,list):X=[x.to(device) for x in X]else:X=X.to(device)y=y.to(device)eva += accuracy(net(X), y)y_num += y.numel()return eva/y_num
训练过程同样将数据送入GPU计算
def train_epoch_gpu(net, train_iter, loss, updater,device):# 训练损失之和,训练准确数之和,样本数train_loss_sum = 0.0train_acc_sum = 0.0num_samples = 0.0# timer = d2l.torch.Timer()for i, (X, y) in enumerate(train_iter):# timer.start()updater.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()updater.step()with torch.no_grad():train_loss_sum += l * X.shape[0]train_acc_sum += evaluation.accuracy(y_hat, y)num_samples += X.shape[0]# timer.stop()return train_loss_sum/num_samples,train_acc_sum/num_samplesdef train_gpu(net,train_iter,test_iter,num_epochs,lr,device):def init_weights(m):if type(m)==torch.nn.Linear or type(m)==torch.nn.Conv2d:torch.nn.init.xavier_uniform_(m.weight)net.apply(init_weights)net.to(device)print('training on',device)optimizer=torch.optim.SGD(net.parameters(),lr=lr)loss=torch.nn.CrossEntropyLoss()# num_batches=len(train_iter)tr_l=[]tr_a=[]te_a=[]for epoch in range(num_epochs):net.train()train_metric=train_epoch_gpu(net,train_iter,loss,optimizer,device)test_accuracy = evaluation.evaluate_accuracy_gpu(net, test_iter)train_loss, train_acc = train_metrictrain_loss = train_loss.cpu().detach().numpy()tr_l.append(train_loss)tr_a.append(train_acc)te_a.append(test_accuracy)print(f'epoch: {epoch + 1}, train_loss: {train_loss}, train_acc: {train_acc}, test_acc:{test_accuracy}')x = torch.arange(num_epochs)plt.plot((x + 1), tr_l, '-', label='train_loss')plt.plot(x + 1, tr_a, '--', label='train_acc')plt.plot(x + 1, te_a, '-.', label='test_acc')plt.legend()plt.show()print(f'on {str(device)}')
lr,num_epochs=0.9,10
Train.train_gpu(net,train_iter,test_iter,num_epochs,lr,device='cuda')


相关文章:
LeNet-5(fashion-mnist)
文章目录 前言LeNet模型训练 前言 LeNet是最早发布的卷积神经网络之一。该模型被提出用于识别图像中的手写数字。 LeNet LeNet-5由以下两个部分组成 卷积编码器(2)全连接层(3) 卷积块由一个卷积层、一个sigmoid激活函数和一个…...
Unity中URP下开启和使用深度图
文章目录 前言一、在Unity中打开URP下的深度图二、在Shader中开启深度图1、使用不透明渲染队列才可以使用深度图2、半透明渲染队列深度图就会关闭 三、URP深度图 和 BRP深度图的区别四、在Shader中,使用深度图1、定义纹理和采样器2、在片元着色器对深度图采样并且输…...
类似东郊到家上门预约系统需要具备哪些功能,预约系统应该怎么做
随着上门服务需求的持续增长,各类APP小程序应运而生。吸引了无数商家投资者,纷纷想要开发一款类似于"东郊到家"这样的上门服务软件。要想成功,这样的软件需具备以下核心功能: 1. 快速注册与登录:用户能通过手…...
鸿蒙APP和Android的区别
鸿蒙(HarmonyOS)和Android是两个不同的操作系统,它们有一些区别,包括架构、开发者支持、应用生态和一些设计理念。以下是鸿蒙APP和Android APP之间的一些主要区别,希望对大家有所帮助。北京木奇移动技术有限公司&#…...
给Flutter + FireBase 增加 badge 徽章,App启动器 通知红点。
在此之前需要配置好 firebase 在flutter 在项目中。(已经配置好的可以忽略此提示) Firebase 配置教程:flutter firebase 云消息通知教程 (android-安卓、ios-苹果)_flutter firebase_messaging ios环境配置-CSDN博客 由于firebase 提供的消息…...
2024年中国杭州|网络安全技能大赛(CTF)正式开启竞赛报名
前言 一、CTF简介 CTF(Capture The Flag)中文一般译作夺旗赛,在网络安全领域中指的是网络安全技术人员之间进行技术竞技的一种比赛形式。CTF起源于1996年DEFCON全球黑客大会,以代替之前黑客们通过互相发起真实攻击进行技术比拼的…...
112.Qt中的窗口类
我们在通过Qt向导窗口基于窗口的应用程序的项目过程中倒数第二步让我们选择跟随项目创建的第一个窗口的基类, 下拉菜单中有三个选项, 分别为: QMainWindow、QDialog、QWidget如下图: 常用的窗口类有3个 在创建Qt窗口的时候, 需要让自己的窗口类继承上述三个窗口类的…...
如何设置电脑桌面提醒,电脑笔记软件哪个好?
对于大多数上班族来说,每天要完成的待办事项实在太多了,如果不能及时去处理,很容易因为各种因素导致忘记,从而给自己带来不少麻烦。所以,我们往往会借助一些提醒类的软件将各项任务逐一记录下来,然后设置上…...
C# HttpClient Get Post简单封装
文章目录 前言封装好的代码测试接口测试代码 前言 微软官方有Get和Post请求,我把他简单化处理一下 封装好的代码 public class MyHttpHelper{private string baseUrl;/// <summary>/// 基础Api/// </summary>public string BaseUrl{get{return baseUr…...
创建网格(Grid/GridItem)
目录 1、概述 2、布局与约束 3、设置排列方式 3.1设置行列数量与占比 3.2、设置子组件所占行列数 3.3、设置主轴方向 3.4、在网格布局中显示数据 3.5、设置行列间距 4、构建可滚动的网格布局 5、实现简单的日历功能 6、性能优化 1、概述 网格布局是由“行”和“列”分…...
思科路由器忘记密码怎么重置
断电重启路由器,在开机过程中按下CtrlPause/break,或者只按下Pause/break(没有测试),在PT(Cisco Packet Tracert)中则需要按CtrlC。路由器会进入rommon >模式。 切换到0x2142模式࿰…...
JVM基础(2)——JVM内存模型
一、简介 JVM会加载类到内存中,所以 JVM 中必然会有一块内存区域来存放我们写的那些类。Java中有类对象、普通对象、本地变量、方法信息等等各种对象信息,所以JVM会对内存区域进行划分: JDK1.8及以后,上图中的方法区变成了Metasp…...
使用 Process Explorer 和 Windbg 排查软件线程堵塞问题
目录 1、问题说明 2、线程堵塞的可能原因分析 3、使用Windbg和Process Explorer确定线程中发生了死循环 4、根据Windbg中显示的函数调用堆栈去查看源码,找到问题 4.1、在Windbg定位发生死循环的函数的方法 4.2、在Windbg中查看变量的值去辅助分析 4.3、是循环…...
做科技类的展台3d模型用什么材质比较好---模大狮模型网
对于科技类展台3D模型,以下是几种常用的材质选择: 金属材质:金属材质常用于科技展台的现代感设计,如不锈钢、铝合金或镀铬材质。金属材质可以赋予展台一个科技感和高档感,同时还可以反射光线,增加模型的真实…...
EasyExcel简单实例(未完待续)
EasyExcel简单实例 准备工作场景一:读取 Student 表需求1:简单读取需求2:读取到异常信息时不中断需求3:读取所有的sheet工作表需求4:读取指定的sheet工作表需求5:从指定的行开始读取 场景二:写入…...
ROS2学习笔记一:安装及测试
目录 前言 1 ROS2安装与卸载 1.1 安装虚拟机 1.2 ROS2 humble安装 2 ROS2测试 2.1 topic测试 2.2 小海龟测试 2.3 RQT可视化 2.4 占用空间 前言 ROS2的前身是ROS,ROS即机器人操作系统(Robot Operating System),ROS为了“提高机器人…...
Xcode14.3.1真机调试iOS17的方法
Hello,大家好我是咕噜铁蛋!Xcode 是苹果官方开发工具,它提供了完整的开发环境和工具集,支持开发 iOS、macOS、watchOS 和 tvOS 应用程序。对于 iOS 开发者来说,Xcode 是必备的工具之一。而随着 iOS 系统的不断更新和升…...
主流大语言模型从预训练到微调的技术原理
引言 本文设计的内容主要包含以下几个方面: 比较 LLaMA、ChatGLM、Falcon 等大语言模型的细节:tokenizer、位置编码、Layer Normalization、激活函数等。大语言模型的分布式训练技术:数据并行、张量模型并行、流水线并行、3D 并行、零冗余优…...
Linux中vim查看文件某内容
一、编辑文件命令 [rootyinheqilin ~]# vim test.txt 1,在编辑的文件中连续按2次键盘的【g】键,光标会移动到文档开头第一行 2,输入一个大写 G,光标会跳转到文件的最后一行第一列(末行) 二、查看文件内容命令 gre…...
阿里云提示服务器ip暴露该怎么办?-速盾网络(sudun)
当阿里云提示服务器IP暴露的时候,这意味着您的服务器可能面临安全风险,因为黑客可以通过知道服务器的IP地址来尝试入侵您的系统。在这种情况下,您应该立即采取措施来保护您的服务器和数据。以下是一些建议: 更改服务器IP地址&…...
谷歌浏览器插件
项目中有时候会用到插件 sync-cookie-extension1.0.0:开发环境同步测试 cookie 至 localhost,便于本地请求服务携带 cookie 参考地址:https://juejin.cn/post/7139354571712757767 里面有源码下载下来,加在到扩展即可使用FeHelp…...
C++实现分布式网络通信框架RPC(3)--rpc调用端
目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中,我们已经大致实现了rpc服务端的各项功能代…...
SciencePlots——绘制论文中的图片
文章目录 安装一、风格二、1 资源 安装 # 安装最新版 pip install githttps://github.com/garrettj403/SciencePlots.git# 安装稳定版 pip install SciencePlots一、风格 简单好用的深度学习论文绘图专用工具包–Science Plot 二、 1 资源 论文绘图神器来了:一行…...
Admin.Net中的消息通信SignalR解释
定义集线器接口 IOnlineUserHub public interface IOnlineUserHub {/// 在线用户列表Task OnlineUserList(OnlineUserList context);/// 强制下线Task ForceOffline(object context);/// 发布站内消息Task PublicNotice(SysNotice context);/// 接收消息Task ReceiveMessage(…...
pam_env.so模块配置解析
在PAM(Pluggable Authentication Modules)配置中, /etc/pam.d/su 文件相关配置含义如下: 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块,负责验证用户身份&am…...
《通信之道——从微积分到 5G》读书总结
第1章 绪 论 1.1 这是一本什么样的书 通信技术,说到底就是数学。 那些最基础、最本质的部分。 1.2 什么是通信 通信 发送方 接收方 承载信息的信号 解调出其中承载的信息 信息在发送方那里被加工成信号(调制) 把信息从信号中抽取出来&am…...
Unity UGUI Button事件流程
场景结构 测试代码 public class TestBtn : MonoBehaviour {void Start(){var btn GetComponent<Button>();btn.onClick.AddListener(OnClick);}private void OnClick(){Debug.Log("666");}}当添加事件时 // 实例化一个ButtonClickedEvent的事件 [Formerl…...
日常一水C
多态 言简意赅:就是一个对象面对同一事件时做出的不同反应 而之前的继承中说过,当子类和父类的函数名相同时,会隐藏父类的同名函数转而调用子类的同名函数,如果要调用父类的同名函数,那么就需要对父类进行引用&#…...
永磁同步电机无速度算法--基于卡尔曼滤波器的滑模观测器
一、原理介绍 传统滑模观测器采用如下结构: 传统SMO中LPF会带来相位延迟和幅值衰减,并且需要额外的相位补偿。 采用扩展卡尔曼滤波器代替常用低通滤波器(LPF),可以去除高次谐波,并且不用相位补偿就可以获得一个误差较小的转子位…...
MyBatis中关于缓存的理解
MyBatis缓存 MyBatis系统当中默认定义两级缓存:一级缓存、二级缓存 默认情况下,只有一级缓存开启(sqlSession级别的缓存)二级缓存需要手动开启配置,需要局域namespace级别的缓存 一级缓存(本地缓存&#…...
