当前位置: 首页 > news >正文

基于CNN的FashionMNIST数据集识别3——模型验证

源码

import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNetdef test_data_process():test_data = FashionMNIST(root='./data',train=False,transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)test_dataloader = Data.DataLoader(dataset=test_data,batch_size=1,shuffle=True,num_workers=0)return test_dataloaderdef test_model_process(model, test_dataloader):# 设定测试所用到的设备,有GPU用GPU没有GPU用CPUdevice = "cuda" if torch.cuda.is_available() else 'cpu'# 讲模型放入到训练设备中model = model.to(device)# 初始化参数test_corrects = 0.0test_num = 0# 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度with torch.no_grad():for test_data_x, test_data_y in test_dataloader:# 将特征放入到测试设备中test_data_x = test_data_x.to(device)# 将标签放入到测试设备中test_data_y = test_data_y.to(device)# 设置模型为评估模式model.eval()# 前向传播过程,输入为测试数据集,输出为对每个样本的预测值output= model(test_data_x)# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 如果预测正确,则准确度test_corrects加1test_corrects += torch.sum(pre_lab == test_data_y.data)# 将所有的测试样本进行累加test_num += test_data_x.size(0)# 计算测试准确率test_acc = test_corrects.double().item() / test_numprint("测试的准确率为:", test_acc)if __name__=="__main__":# 加载模型model = LeNet()model.load_state_dict(torch.load('best_model.pth'))# 加载测试数据test_dataloader = test_data_process()# 加载模型测试的函数test_model_process(model, test_dataloader)

源码讲解

当模型训练完毕后,我们得到的是一组最优的参数配置。

最后要做的就是验证这组参数的表现。

数据准备

def test_data_process():test_data = FashionMNIST(root='./data',train=False,transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)test_dataloader = Data.DataLoader(dataset=test_data,batch_size=1,shuffle=True,num_workers=0)return test_dataloader

测试数据的准备和训练数据的准备有明显不同:

  1. 测试数据集是将MINIST里所有的数据当做验证集。并且训练模式设置为false。
  2. dataloader里面的batch大小设置为1,也就是说对每个样本进行验证,不再存在“分批”的概念。

循环验证

    # 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度with torch.no_grad():for test_data_x, test_data_y in test_dataloader:# 将特征放入到测试设备中test_data_x = test_data_x.to(device)# 将标签放入到测试设备中test_data_y = test_data_y.to(device)# 设置模型为评估模式model.eval()# 前向传播过程,输入为测试数据集,输出为对每个样本的预测值output= model(test_data_x)# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 如果预测正确,则准确度test_corrects加1test_corrects += torch.sum(pre_lab == test_data_y.data)# 将所有的测试样本进行累加test_num += test_data_x.size(0)

和之前训练模型时的验证逻辑基本相同。只进行前向传播,将预测正确的样本个数进行累加。

代码运行

if __name__=="__main__":# 加载模型model = LeNet()model.load_state_dict(torch.load('best_model.pth'))# 加载测试数据test_dataloader = test_data_process()# 加载模型测试的函数test_model_process(model, test_dataloader)

在代码运行时,需要将参数载入到模型里,再进行验证。

相关文章:

基于CNN的FashionMNIST数据集识别3——模型验证

源码 import torch import torch.utils.data as Data from torchvision import transforms from torchvision.datasets import FashionMNIST from model import LeNetdef test_data_process():test_data FashionMNIST(root./data,trainFalse,transformtransforms.Compose([tr…...

go channel 的用法和核心原理、使用场景

一、Channel 的核心用法 1. 基本操作 // 创建无缓冲 Channel&#xff08;同步通信&#xff09; ch : make(chan int) // 创建有缓冲 Channel&#xff08;容量为5&#xff0c;异步通信&#xff09; bufferedCh : make(chan int, 5) // 发送数据到 Channel ch <- 42 // 从…...

pyside6学习专栏(七):自定义QTableWidget的扩展子类QTableWidgetEx

PySide6界面编程中较常用的控件还有QTableWidget表格控件&#xff0c;用来将加载的数据在表格中显示出来&#xff0c;下面继承QTableWidget编写其扩展子类QTableWidgetEx,来实现用单元格来显示除数据文字外&#xff0c;还可以对表格的单元格的文字颜色、背景底色进行设置&#…...

Mybatis常用动态 SQL 相关标签

1. <if> 用于条件判断&#xff0c;当满足条件时执行对应的 SQL 片段。 示例: <select id"findUser" resultType"User">SELECT * FROM usersWHERE 11<if test"name ! null and name ! ">AND name #{name}</if><if…...

AWQ和GPTQ量化的区别

一、前言 本地化部署deepseek时发现&#xff0c;如果是量化版的deepseek&#xff0c;会节约很多的内容&#xff0c;然后一般有两种量化技术&#xff0c;那么这两种量化技术有什么区别呢&#xff1f; 二、量化技术对比 在模型量化领域&#xff0c;AWQ 和 GPTQ 是两种不同的量…...

ESP32S3:解决RWDT无法触发中断问题,二次开发者怎么才能使用内部RTC看门狗中断RWDT呢?

目录 基于ESP32S3:解决RWDT无法触发中断问题引言解决方案1. 查看报错日志2. 分析报错及一步一步找到解决方法3.小结我的源码基于ESP32S3:解决RWDT无法触发中断问题 引言 在嵌入式系统中,RWDT(看门狗定时器)是确保系统稳定性的重要组件。然而,在某些情况下,RWDT可能无法…...

基于SpringBoot的民宿管理系统的设计与实现(源码+SQL脚本+LW+部署讲解等)

专注于大学生项目实战开发,讲解,毕业答疑辅导&#xff0c;欢迎高校老师/同行前辈交流合作✌。 技术范围&#xff1a;SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容&#xff1a;…...

go 日志框架

内置log import ("log""os" )func main() {// 设置loglog.SetFlags(log.Llongfile | log.Lmicroseconds | log.Ldate)// 自定义日志前缀log.SetPrefix("[pprof]")log.Println("main ..")// 如果用format就用PrintF&#xff0c;而不是…...

如何在 PDF 文件中嵌入自定义数据

由于 PDF 文件格式功能强大且灵活&#xff0c;它经常被用于内部工作流程。有时候&#xff0c;将自定义数据嵌入 PDF 文件本身会非常有用。通常&#xff0c;这些信息会被大多数工具忽略&#xff0c;因此 PDF 仍然可以作为普通 PDF 文件正常使用。 以下是一些实现方法&#xff1…...

计算机毕业设计SpringBoot+Vue.js服装商城 服装购物系统(源码+LW文档+PPT+讲解+开题报告)

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 作者简介&#xff1a;Java领…...

22.回溯算法4

递增子序列 这里不能排序&#xff0c;因为数组的顺序是对结果有影响的&#xff0c;所以只能通过used数组来去重 class Solution { public:vector<int> path;vector<vector<int>> res;void backtracking(vector<int>& nums,int start){if(path.si…...

linux -对文件描述符的操作dup、fcntl有五种

dup #include<unistd.h> int dup(int oldfd);作用&#xff1a;复制一个新的文件描述符fd 3, int fd1 dup(fd);f指向的是a.txt,fd1指向的也是a.txt从空闲的文件描述符表中找一个最小的作为新的拷贝的文件描述符返回&#xff1a;成功返回新的文件描述符&#xff0c;失败…...

技术解析 | 适用于TeamCity的Unreal Engine支持插件,提升游戏构建效率

龙智是JetBrains授权合作伙伴、Perforce授权合作伙伴&#xff0c;为您提供TeamCity、Perforce Helix Core等热门的游戏开发工具及一站式服务 TeamCity 是游戏开发的热门选择&#xff0c;大家选择它的原因包括支持 Perforce、可以进行本地安装&#xff0c;并提供了多种配置选项。…...

Ubuntu22.04 - brpc的安装和使用

目录 介绍安装使用 介绍 brpc 是用 c语言编写的工业级 RPC 框架&#xff0c;常用于搜索、存储、机器学习、广告、推荐等高性能系统 安装 先安装依赖 apt-get install -y git g make libssl-dev libprotobuf-dev libprotoc-dev protobuf-compiler libleveldb-dev libgflags-d…...

网络运维学习笔记 018 HCIA-Datacom综合实验02

文章目录 综合实验2sw3&#xff1a;sw4&#xff1a;gw&#xff1a;core1&#xff08;sw1&#xff09;&#xff1a;core2&#xff08;sw2&#xff09;&#xff1a;ISP 综合实验2 sw3&#xff1a; vlan 2 stp mode stp int e0/0/1 port link-type trunk port trunk allow-pass v…...

Vulhub靶机 Apache Druid(CVE-2021-25646)(渗透测试详解)

一、开启vulhub环境 docker-compose up -d 启动 docker ps 查看开放的端口 1、漏洞范围 在Druid0.20.0及更低版本中 二、访问靶机IP 8888端口 1、点击Load data进入新界面后&#xff0c;再点击local disk按钮。 2、进入新界面后&#xff0c;在标红框的Base directory栏写上…...

VSCode配置自动生成头文件

一、配置步骤&#xff1a; 1.打开命令面板&#xff08;CtrlShiftp&#xff09;&#xff1a; 2.输入snippets 选择配置代码片段 3. 选择新建全局代码片段 输入文件名,比如header_cpp(随便定义)&#xff0c;然后点击键盘回车按钮&#xff0c;得到下面这个文件。 增加配置文…...

Xcode如何高效的一键重命名某个关键字

1.选中某个需要修改的关键字&#xff1b; 2.右击&#xff0c;选择Refactor->Rename… 然后就会出现如下界面&#xff1a; 此时就可以一键重命名了。 还可以设置快捷键。 1.打开Settings 2.找到Key Bindings 3.搜索rename 4.出现三个&#xff0c;点击一个地方设置后其…...

React 高阶组件的优缺点

React 高阶组件的优缺点 优点 1. 代码复用性高 公共逻辑封装&#xff1a;当多个组件需要实现相同的功能或逻辑时&#xff0c;高阶组件可以将这些逻辑封装起来&#xff0c;避免代码重复。例如&#xff0c;多个组件都需要在挂载时进行数据获取操作&#xff0c;就可以创建一个数…...

(五)趣学设计模式 之 建造者模式!

目录 一、 啥是建造者模式&#xff1f;二、 为什么要用建造者模式&#xff1f;三、 建造者模式怎么实现&#xff1f;四、 建造者模式的应用场景五、 建造者模式的优点和缺点六、 总结 &#x1f31f;我的其他文章也讲解的比较有趣&#x1f601;&#xff0c;如果喜欢博主的讲解方…...

浅谈 React Hooks

React Hooks 是 React 16.8 引入的一组 API&#xff0c;用于在函数组件中使用 state 和其他 React 特性&#xff08;例如生命周期方法、context 等&#xff09;。Hooks 通过简洁的函数接口&#xff0c;解决了状态与 UI 的高度解耦&#xff0c;通过函数式编程范式实现更灵活 Rea…...

基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销&#xff0c;平衡网络负载&#xff0c;延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...

对WWDC 2025 Keynote 内容的预测

借助我们以往对苹果公司发展路径的深入研究经验&#xff0c;以及大语言模型的分析能力&#xff0c;我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际&#xff0c;我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测&#xff0c;聊作存档。等到明…...

【论文笔记】若干矿井粉尘检测算法概述

总的来说&#xff0c;传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度&#xff0c;通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...

NLP学习路线图(二十三):长短期记忆网络(LSTM)

在自然语言处理(NLP)领域,我们时刻面临着处理序列数据的核心挑战。无论是理解句子的结构、分析文本的情感,还是实现语言的翻译,都需要模型能够捕捉词语之间依时序产生的复杂依赖关系。传统的神经网络结构在处理这种序列依赖时显得力不从心,而循环神经网络(RNN) 曾被视为…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作&#xff1a;ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等&#xff08;ArcGIS出图图例8大技巧&#xff09;&#xff0c;那这次我们看看ArcGIS Pro如何更加快捷的操作。…...

项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)

Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败&#xff0c;具体原因是客户端发送了密码认证请求&#xff0c;但Redis服务器未设置密码 1.为Redis设置密码&#xff08;匹配客户端配置&#xff09; 步骤&#xff1a; 1&#xff09;.修…...

短视频矩阵系统文案创作功能开发实践,定制化开发

在短视频行业迅猛发展的当下&#xff0c;企业和个人创作者为了扩大影响力、提升传播效果&#xff0c;纷纷采用短视频矩阵运营策略&#xff0c;同时管理多个平台、多个账号的内容发布。然而&#xff0c;频繁的文案创作需求让运营者疲于应对&#xff0c;如何高效产出高质量文案成…...

Python Ovito统计金刚石结构数量

大家好,我是小马老师。 本文介绍python ovito方法统计金刚石结构的方法。 Ovito Identify diamond structure命令可以识别和统计金刚石结构,但是无法直接输出结构的变化情况。 本文使用python调用ovito包的方法,可以持续统计各步的金刚石结构,具体代码如下: from ovito…...

人工智能--安全大模型训练计划:基于Fine-tuning + LLM Agent

安全大模型训练计划&#xff1a;基于Fine-tuning LLM Agent 1. 构建高质量安全数据集 目标&#xff1a;为安全大模型创建高质量、去偏、符合伦理的训练数据集&#xff0c;涵盖安全相关任务&#xff08;如有害内容检测、隐私保护、道德推理等&#xff09;。 1.1 数据收集 描…...