当前位置: 首页 > news >正文

Tensorflow2.0笔记 - metrics做损失和准确度信息度量

        本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下:

Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博客文章浏览阅读339次。本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。关于FashionMnist的介绍,请自行百度。https://blog.csdn.net/vivo01/article/details/136921592?spm=1001.2014.3001.5502

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__#加载fashion mnist数据集
def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels#预处理数据
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32)x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)y = tf.convert_to_tensor(y, dtype=tf.int32)return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)batch_size = 128train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)#定义网络模型
model = Sequential([#Layer 1: [b, 784] => [b, 256]layers.Dense(256, activation=tf.nn.relu),#Layer 2: [b, 256] => [b, 128]layers.Dense(128, activation=tf.nn.relu),#Layer 3: [b, 128] => [b, 64]layers.Dense(64, activation=tf.nn.relu),#Layer 4: [b, 64] => [b, 32]layers.Dense(32, activation=tf.nn.relu),#Layer 5: [b, 32] => [b, 10], 输出类别结果layers.Dense(10)
])#编译网络
model.build(input_shape=[None, 28*28])
model.summary()#进行训练
total_epoches = 5
learn_rate = 0.01#Metrics统计
#参考资料:https://zhuanlan.zhihu.com/p/42438077
#1. 新建meter
#acc_meter = metrics.Accuracy()
#loss_meter = metrics.Mean()
#2. 更新状态, update_state()
#loss_meter.update_state(loss)
#acc_meter.update_state(y, pred)
#3.获取结果, result()
#print(step, 'loss:', loss_meter.result().numpy())
#print(step, 'Evaluate Acc:', total_correct/total, acc_meter.result().numpy())
#4.清除度量信息,reset_states()
#loss_meter.reset_states()
#acc_meter.reset_states()#新建准确度和loss度量对象
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()optimizer = optimizers.Adam(learning_rate = learn_rate)
for epoch in range(total_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:logits = model(x)y_onehot = tf.one_hot(y, depth=10)#使用交叉熵作为lossloss_ce = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))#调用update_state更新loss度量信息loss_meter.update_state(loss_ce)#计算梯度grads = tape.gradient(loss_ce, model.trainable_variables)#更新梯度optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print("Epoch[", epoch, "]: step-", step, "\tloss: ", loss_meter.result().numpy())loss_meter.reset_states()#使用测试集进行验证total_correct = 0total_num = 0#清除准确度的统计信息acc_meter.reset_states()for x,y in test_db:logits = model(x)#使用softmax得到各个类别的概率prob = tf.nn.softmax(logits, axis=1)#求出概率最大的结果参数位置,作为预测的分类结果pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)#比较结果correct = tf.equal(pred, y)correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))#计算精度total_correct += int(correct)total_num += x.shape[0]#使用metircs的update_state进行更新acc_meter.update_state(y, pred)acc = total_correct / total_numprint("Epoch[", epoch, "] Manual Accuracy:", acc, " Metrics Accuracy:", acc_meter.result().numpy())

运行结果:

相关文章:

Tensorflow2.0笔记 - metrics做损失和准确度信息度量

本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下: Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博…...

LeetCode 面试经典150题 290.单词规律

题目: 给定一种规律 pattern 和一个字符串 s ,判断 s 是否遵循相同的规律。 这里的 遵循 指完全匹配,例如, pattern 里的每个字母和字符串 s 中的每个非空单词之间存在着双向连接的对应规律。 思路:一一映射需要用到…...

【CASS精品教程】CASS中计算四参数和七参数(以RTK数据为例)

文章目录 一、四参数介绍二、四参数计算三、七参数介绍四、四参数、七参数的区别一、四参数介绍 两个不同的二维平面直角坐标系之间转换通常使用四参数模型,四参数适合小范围测区的空间坐标转换,相对于七参数转换的优势在于只需要2个公共已知点就能进行转换,操作简单。 在…...

什么是RISC-V?开源 ISA 如何重塑未来的处理器设计

RISC-V代表了处理器架构的范式转变,特点是其开源模型简化了设计理念并促进了全球community-driven的开发。RISC-V导致了处理器技术发展前进方式的重大转变,提供了一个不受传统复杂性阻碍的全新视角。 RISC-V起源于加州大学伯克利分校的学术起点&#xff…...

展馆设计中展示有哪些要求

1、展示产品特点和功能 产品展示应突出产品的特点、功能和优势。通过使用适当的展示方式和展示环境,使观众能够直观地了解产品的外观、结构、性能等方面。可以使用实物展示、模型、原型、图表、动画等方式,以多种角度和视角展示产品的特点和功能。 2、提…...

python实战之PyQt5桌面软件

一. 演示效果 二. 准备工作 1. 使用pip 下载所需包 pyqt5 2. 下载可视化UI工具 QT Designer 链接:https://pan.baidu.com/s/1ic4S3ocEF90Y4L1GqYHPPA?pwdywct 提取码:ywct 3. 可视化UI工具汉化 把上面的链接打开, 里面有安装和汉化包, 前面的路径还要看…...

Switch 和 PS1 模拟器:3000+ 游戏随心玩 | 开源日报 No.174

Ryujinx/Ryujinx Stars: 26.1k License: MIT Ryujinx 是用 C# 编写的实验性任天堂 Switch 模拟器。 该项目旨在提供出色的准确性和性能、用户友好的界面以及稳定的构建。它已经通过了大约 4050 个测试,其中超过 4000 个可以启动并进入游戏,其中大约 340…...

免费翻译pdf格式论文

进入谷歌翻译网址https://translate.google.com/?slauto&tlzh-CN&opdocs 将需要全文翻译的pdf放进去 选择英文到中文,然后点击翻译 可以选择打开译文或者下载译文,下载译文会下载到电脑上,打开译文会在浏览器打开。...

3D产品可视化SaaS

“我们正在走向衰退吗?” “我们已经陷入衰退了吗?” “我们正在步入衰退。” 过去几个月占据头条的问题和陈述引发了关于市场对每个行业影响的讨论和激烈辩论。 特别是对于科技行业来说,过去几周一直很动荡,围绕费用、增长和裁…...

浙大版《C语言程序设计(第4版)》题目集-习题3-5 三角形判断

给定平面上任意三个点的坐标(x1,y1)、(x2,y2)、(x3,y3),检验它们能否构成三角形。 输入格式: 输入在一行中顺序给出六个[−100,100]范围内的数字,即三个点的坐标x1、y1、x2、y2、x3、y3。 输出格式: 若这3个点不能构成三角形,则在一行中输…...

Java封装、继承、多态和抽象深度解析

在软件工程的世界里,面向对象编程(OOP)是一种编程范式,它使用“对象”来设计软件。对象可以封装数据和方法,以提高代码的复用性、可维护性和可扩展性。Java作为一门面向对象的编程语言,提供了四个基本的面向…...

深度学习每周学习总结P3(天气识别)

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制 数据链接 提取码:o3ix 目录 0. 总结1. 数据导入部分数据导入部分代码详解:a. 数据读取部分a.1 提问:关…...

通过iOS网络抓包工具实现移动应用数据安全监控

摘要 本文将深入探讨iOS平台上常用的网络抓包工具,包括Charles、克魔助手、Thor和Http Catcher,以及通过SSH连接进行抓包的方法。此外,还介绍了克魔开发助手作为iOS应用开发的辅助工具,提供的全方面性能监控和调试功能。 在iOS应…...

Stable Diffusion WebUI 生成参数:脚本(Script)——提示词矩阵、从文本框或文件载入提示词、X/Y/Z图表

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里,订阅后可阅读专栏内所有文章。 大家好,我是水滴~~ 在本篇文章中,我们将深入探讨 Stable Diffusion WebUI 的另一个引人注目的生成参数——脚本(Script)。我们将逐一细说提示词矩阵、从文本框或文件导入提示词,…...

synchronized和volatile的原理及应用

文章目录 synchronized的实现原理及应用升级锁代码示例volatile原理及应用代码示例线程不安全类 synchronized的实现原理及应用 synchronized 是Java中用于实现线程同步的关键字,可以应用于方法或代码块,确保在多线程环境下对共享资源的安全访问。下面是…...

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之九 简单闪烁效果

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之九 简单闪烁效果 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之九 简单闪烁效果 一、简单介绍 二、简单闪烁效果实现原理 三、简单闪烁效果案例实现简单步骤 四、注意事项 一、简单…...

11 开源鸿蒙OpenHarmony轻量系统源码分析

开源鸿蒙轻量系统源码分析 作者将狼才鲸日期2024-03-28 一、前言 之前单独的LiteOS是通过Makefile编译的,当前的开源鸿蒙LiteOS-M和LiteOS-A是通过gn和ninja编译的。 Gitee官方只介绍了LiteOS-M的gn ninja编译的流程,针对M3使用Keil编译的流程可能要参…...

专题:一个自制代码生成器(嵌入式脚本语言)之应用实例

初级代码游戏的专栏介绍与文章目录-CSDN博客 我的github:codetoys,所有代码都将会位于ctfc库中。已经放入库中我会指出在库中的位置。 这些代码大部分以Linux为目标但部分代码是纯C的,可以在任何平台上使用。 专题:一个自制代码…...

Appium设备交互API

设备交互API指的是操作设备系统中的一些固有功能,而非被测程序的功能,例如模拟来电,模拟发送短信,设置网络,切换横竖屏,APP操作,打开通知栏,录屏等。 模拟来电 make_gsm_call(phon…...

Qlib-Server部署

Qlib-Server部署 介绍 构建Qlib服务器,用户可以选择: 一键部署Qlib服务器逐步部署Qlib服务器一键部署 Qlib服务器支持一键部署,用户可以选择以下两种方法之一进行一键部署: 使用docker-compose部署在Azure中部署使用docker-compose进行一键部署 按照以下步骤使用docker…...

SciencePlots——绘制论文中的图片

文章目录 安装一、风格二、1 资源 安装 # 安装最新版 pip install githttps://github.com/garrettj403/SciencePlots.git# 安装稳定版 pip install SciencePlots一、风格 简单好用的深度学习论文绘图专用工具包–Science Plot 二、 1 资源 论文绘图神器来了:一行…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

质量体系的重要

质量体系是为确保产品、服务或过程质量满足规定要求,由相互关联的要素构成的有机整体。其核心内容可归纳为以下五个方面: 🏛️ 一、组织架构与职责 质量体系明确组织内各部门、岗位的职责与权限,形成层级清晰的管理网络&#xf…...

前端开发面试题总结-JavaScript篇(一)

文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包(Closure)?闭包有什么应用场景和潜在问题?2.解释 JavaScript 的作用域链(Scope Chain) 二、原型与继承3.原型链是什么?如何实现继承&a…...

微信小程序云开发平台MySQL的连接方式

注:微信小程序云开发平台指的是腾讯云开发 先给结论:微信小程序云开发平台的MySQL,无法通过获取数据库连接信息的方式进行连接,连接只能通过云开发的SDK连接,具体要参考官方文档: 为什么? 因为…...

实现弹窗随键盘上移居中

实现弹窗随键盘上移的核心思路 在Android中&#xff0c;可以通过监听键盘的显示和隐藏事件&#xff0c;动态调整弹窗的位置。关键点在于获取键盘高度&#xff0c;并计算剩余屏幕空间以重新定位弹窗。 // 在Activity或Fragment中设置键盘监听 val rootView findViewById<V…...

精益数据分析(97/126):邮件营销与用户参与度的关键指标优化指南

精益数据分析&#xff08;97/126&#xff09;&#xff1a;邮件营销与用户参与度的关键指标优化指南 在数字化营销时代&#xff0c;邮件列表效度、用户参与度和网站性能等指标往往决定着创业公司的增长成败。今天&#xff0c;我们将深入解析邮件打开率、网站可用性、页面参与时…...

C# 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...

R语言速释制剂QBD解决方案之三

本文是《Quality by Design for ANDAs: An Example for Immediate-Release Dosage Forms》第一个处方的R语言解决方案。 第一个处方研究评估原料药粒径分布、MCC/Lactose比例、崩解剂用量对制剂CQAs的影响。 第二处方研究用于理解颗粒外加硬脂酸镁和滑石粉对片剂质量和可生产…...

人机融合智能 | “人智交互”跨学科新领域

本文系统地提出基于“以人为中心AI(HCAI)”理念的人-人工智能交互(人智交互)这一跨学科新领域及框架,定义人智交互领域的理念、基本理论和关键问题、方法、开发流程和参与团队等,阐述提出人智交互新领域的意义。然后,提出人智交互研究的三种新范式取向以及它们的意义。最后,总结…...