当前位置: 首页 > news >正文

Tensorflow2.0笔记 - metrics做损失和准确度信息度量

        本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下:

Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博客文章浏览阅读339次。本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。关于FashionMnist的介绍,请自行百度。https://blog.csdn.net/vivo01/article/details/136921592?spm=1001.2014.3001.5502

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__#加载fashion mnist数据集
def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels#预处理数据
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32)x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)y = tf.convert_to_tensor(y, dtype=tf.int32)return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)batch_size = 128train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)#定义网络模型
model = Sequential([#Layer 1: [b, 784] => [b, 256]layers.Dense(256, activation=tf.nn.relu),#Layer 2: [b, 256] => [b, 128]layers.Dense(128, activation=tf.nn.relu),#Layer 3: [b, 128] => [b, 64]layers.Dense(64, activation=tf.nn.relu),#Layer 4: [b, 64] => [b, 32]layers.Dense(32, activation=tf.nn.relu),#Layer 5: [b, 32] => [b, 10], 输出类别结果layers.Dense(10)
])#编译网络
model.build(input_shape=[None, 28*28])
model.summary()#进行训练
total_epoches = 5
learn_rate = 0.01#Metrics统计
#参考资料:https://zhuanlan.zhihu.com/p/42438077
#1. 新建meter
#acc_meter = metrics.Accuracy()
#loss_meter = metrics.Mean()
#2. 更新状态, update_state()
#loss_meter.update_state(loss)
#acc_meter.update_state(y, pred)
#3.获取结果, result()
#print(step, 'loss:', loss_meter.result().numpy())
#print(step, 'Evaluate Acc:', total_correct/total, acc_meter.result().numpy())
#4.清除度量信息,reset_states()
#loss_meter.reset_states()
#acc_meter.reset_states()#新建准确度和loss度量对象
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()optimizer = optimizers.Adam(learning_rate = learn_rate)
for epoch in range(total_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:logits = model(x)y_onehot = tf.one_hot(y, depth=10)#使用交叉熵作为lossloss_ce = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))#调用update_state更新loss度量信息loss_meter.update_state(loss_ce)#计算梯度grads = tape.gradient(loss_ce, model.trainable_variables)#更新梯度optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print("Epoch[", epoch, "]: step-", step, "\tloss: ", loss_meter.result().numpy())loss_meter.reset_states()#使用测试集进行验证total_correct = 0total_num = 0#清除准确度的统计信息acc_meter.reset_states()for x,y in test_db:logits = model(x)#使用softmax得到各个类别的概率prob = tf.nn.softmax(logits, axis=1)#求出概率最大的结果参数位置,作为预测的分类结果pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)#比较结果correct = tf.equal(pred, y)correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))#计算精度total_correct += int(correct)total_num += x.shape[0]#使用metircs的update_state进行更新acc_meter.update_state(y, pred)acc = total_correct / total_numprint("Epoch[", epoch, "] Manual Accuracy:", acc, " Metrics Accuracy:", acc_meter.result().numpy())

运行结果:

相关文章:

Tensorflow2.0笔记 - metrics做损失和准确度信息度量

本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下: Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博…...

LeetCode 面试经典150题 290.单词规律

题目: 给定一种规律 pattern 和一个字符串 s ,判断 s 是否遵循相同的规律。 这里的 遵循 指完全匹配,例如, pattern 里的每个字母和字符串 s 中的每个非空单词之间存在着双向连接的对应规律。 思路:一一映射需要用到…...

【CASS精品教程】CASS中计算四参数和七参数(以RTK数据为例)

文章目录 一、四参数介绍二、四参数计算三、七参数介绍四、四参数、七参数的区别一、四参数介绍 两个不同的二维平面直角坐标系之间转换通常使用四参数模型,四参数适合小范围测区的空间坐标转换,相对于七参数转换的优势在于只需要2个公共已知点就能进行转换,操作简单。 在…...

什么是RISC-V?开源 ISA 如何重塑未来的处理器设计

RISC-V代表了处理器架构的范式转变,特点是其开源模型简化了设计理念并促进了全球community-driven的开发。RISC-V导致了处理器技术发展前进方式的重大转变,提供了一个不受传统复杂性阻碍的全新视角。 RISC-V起源于加州大学伯克利分校的学术起点&#xff…...

展馆设计中展示有哪些要求

1、展示产品特点和功能 产品展示应突出产品的特点、功能和优势。通过使用适当的展示方式和展示环境,使观众能够直观地了解产品的外观、结构、性能等方面。可以使用实物展示、模型、原型、图表、动画等方式,以多种角度和视角展示产品的特点和功能。 2、提…...

python实战之PyQt5桌面软件

一. 演示效果 二. 准备工作 1. 使用pip 下载所需包 pyqt5 2. 下载可视化UI工具 QT Designer 链接:https://pan.baidu.com/s/1ic4S3ocEF90Y4L1GqYHPPA?pwdywct 提取码:ywct 3. 可视化UI工具汉化 把上面的链接打开, 里面有安装和汉化包, 前面的路径还要看…...

Switch 和 PS1 模拟器:3000+ 游戏随心玩 | 开源日报 No.174

Ryujinx/Ryujinx Stars: 26.1k License: MIT Ryujinx 是用 C# 编写的实验性任天堂 Switch 模拟器。 该项目旨在提供出色的准确性和性能、用户友好的界面以及稳定的构建。它已经通过了大约 4050 个测试,其中超过 4000 个可以启动并进入游戏,其中大约 340…...

免费翻译pdf格式论文

进入谷歌翻译网址https://translate.google.com/?slauto&tlzh-CN&opdocs 将需要全文翻译的pdf放进去 选择英文到中文,然后点击翻译 可以选择打开译文或者下载译文,下载译文会下载到电脑上,打开译文会在浏览器打开。...

3D产品可视化SaaS

“我们正在走向衰退吗?” “我们已经陷入衰退了吗?” “我们正在步入衰退。” 过去几个月占据头条的问题和陈述引发了关于市场对每个行业影响的讨论和激烈辩论。 特别是对于科技行业来说,过去几周一直很动荡,围绕费用、增长和裁…...

浙大版《C语言程序设计(第4版)》题目集-习题3-5 三角形判断

给定平面上任意三个点的坐标(x1,y1)、(x2,y2)、(x3,y3),检验它们能否构成三角形。 输入格式: 输入在一行中顺序给出六个[−100,100]范围内的数字,即三个点的坐标x1、y1、x2、y2、x3、y3。 输出格式: 若这3个点不能构成三角形,则在一行中输…...

Java封装、继承、多态和抽象深度解析

在软件工程的世界里,面向对象编程(OOP)是一种编程范式,它使用“对象”来设计软件。对象可以封装数据和方法,以提高代码的复用性、可维护性和可扩展性。Java作为一门面向对象的编程语言,提供了四个基本的面向…...

深度学习每周学习总结P3(天气识别)

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制 数据链接 提取码:o3ix 目录 0. 总结1. 数据导入部分数据导入部分代码详解:a. 数据读取部分a.1 提问:关…...

通过iOS网络抓包工具实现移动应用数据安全监控

摘要 本文将深入探讨iOS平台上常用的网络抓包工具,包括Charles、克魔助手、Thor和Http Catcher,以及通过SSH连接进行抓包的方法。此外,还介绍了克魔开发助手作为iOS应用开发的辅助工具,提供的全方面性能监控和调试功能。 在iOS应…...

Stable Diffusion WebUI 生成参数:脚本(Script)——提示词矩阵、从文本框或文件载入提示词、X/Y/Z图表

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里,订阅后可阅读专栏内所有文章。 大家好,我是水滴~~ 在本篇文章中,我们将深入探讨 Stable Diffusion WebUI 的另一个引人注目的生成参数——脚本(Script)。我们将逐一细说提示词矩阵、从文本框或文件导入提示词,…...

synchronized和volatile的原理及应用

文章目录 synchronized的实现原理及应用升级锁代码示例volatile原理及应用代码示例线程不安全类 synchronized的实现原理及应用 synchronized 是Java中用于实现线程同步的关键字,可以应用于方法或代码块,确保在多线程环境下对共享资源的安全访问。下面是…...

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之九 简单闪烁效果

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之九 简单闪烁效果 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之九 简单闪烁效果 一、简单介绍 二、简单闪烁效果实现原理 三、简单闪烁效果案例实现简单步骤 四、注意事项 一、简单…...

11 开源鸿蒙OpenHarmony轻量系统源码分析

开源鸿蒙轻量系统源码分析 作者将狼才鲸日期2024-03-28 一、前言 之前单独的LiteOS是通过Makefile编译的,当前的开源鸿蒙LiteOS-M和LiteOS-A是通过gn和ninja编译的。 Gitee官方只介绍了LiteOS-M的gn ninja编译的流程,针对M3使用Keil编译的流程可能要参…...

专题:一个自制代码生成器(嵌入式脚本语言)之应用实例

初级代码游戏的专栏介绍与文章目录-CSDN博客 我的github:codetoys,所有代码都将会位于ctfc库中。已经放入库中我会指出在库中的位置。 这些代码大部分以Linux为目标但部分代码是纯C的,可以在任何平台上使用。 专题:一个自制代码…...

Appium设备交互API

设备交互API指的是操作设备系统中的一些固有功能,而非被测程序的功能,例如模拟来电,模拟发送短信,设置网络,切换横竖屏,APP操作,打开通知栏,录屏等。 模拟来电 make_gsm_call(phon…...

Qlib-Server部署

Qlib-Server部署 介绍 构建Qlib服务器,用户可以选择: 一键部署Qlib服务器逐步部署Qlib服务器一键部署 Qlib服务器支持一键部署,用户可以选择以下两种方法之一进行一键部署: 使用docker-compose部署在Azure中部署使用docker-compose进行一键部署 按照以下步骤使用docker…...

mongodb源码分析session执行handleRequest命令find过程

mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程,并且验证connection是否超过限制ASIOSession和connection是循环接受客户端命令,把数据流转换成Message,状态转变流程是:State::Created 》 St…...

微信小程序 - 手机震动

一、界面 <button type"primary" bindtap"shortVibrate">短震动</button> <button type"primary" bindtap"longVibrate">长震动</button> 二、js逻辑代码 注&#xff1a;文档 https://developers.weixin.qq…...

ardupilot 开发环境eclipse 中import 缺少C++

目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...

工业自动化时代的精准装配革新:迁移科技3D视觉系统如何重塑机器人定位装配

AI3D视觉的工业赋能者 迁移科技成立于2017年&#xff0c;作为行业领先的3D工业相机及视觉系统供应商&#xff0c;累计完成数亿元融资。其核心技术覆盖硬件设计、算法优化及软件集成&#xff0c;通过稳定、易用、高回报的AI3D视觉系统&#xff0c;为汽车、新能源、金属制造等行…...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角&#xff0c;以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向&#xff0c;距离坐标原点x个像素;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐标原点y个像素。 坐标体系-像素 …...

【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分

一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计&#xff0c;提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合&#xff1a;各模块职责清晰&#xff0c;便于独立开发…...

管理学院权限管理系统开发总结

文章目录 &#x1f393; 管理学院权限管理系统开发总结 - 现代化Web应用实践之路&#x1f4dd; 项目概述&#x1f3d7;️ 技术架构设计后端技术栈前端技术栈 &#x1f4a1; 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 &#x1f5c4;️ 数据库设…...

基于SpringBoot在线拍卖系统的设计和实现

摘 要 随着社会的发展&#xff0c;社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 在线拍卖系统&#xff0c;主要的模块包括管理员&#xff1b;首页、个人中心、用户管理、商品类型管理、拍卖商品管理、历史竞拍管理、竞拍订单…...

Mysql8 忘记密码重置,以及问题解决

1.使用免密登录 找到配置MySQL文件&#xff0c;我的文件路径是/etc/mysql/my.cnf&#xff0c;有的人的是/etc/mysql/mysql.cnf 在里最后加入 skip-grant-tables重启MySQL服务 service mysql restartShutting down MySQL… SUCCESS! Starting MySQL… SUCCESS! 重启成功 2.登…...

C# 表达式和运算符(求值顺序)

求值顺序 表达式可以由许多嵌套的子表达式构成。子表达式的求值顺序可以使表达式的最终值发生 变化。 例如&#xff0c;已知表达式3*52&#xff0c;依照子表达式的求值顺序&#xff0c;有两种可能的结果&#xff0c;如图9-3所示。 如果乘法先执行&#xff0c;结果是17。如果5…...