当前位置: 首页 > news >正文

Tensorflow2.0笔记 - metrics做损失和准确度信息度量

        本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下:

Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博客文章浏览阅读339次。本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。关于FashionMnist的介绍,请自行百度。https://blog.csdn.net/vivo01/article/details/136921592?spm=1001.2014.3001.5502

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__#加载fashion mnist数据集
def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels#预处理数据
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32)x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)y = tf.convert_to_tensor(y, dtype=tf.int32)return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)batch_size = 128train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)#定义网络模型
model = Sequential([#Layer 1: [b, 784] => [b, 256]layers.Dense(256, activation=tf.nn.relu),#Layer 2: [b, 256] => [b, 128]layers.Dense(128, activation=tf.nn.relu),#Layer 3: [b, 128] => [b, 64]layers.Dense(64, activation=tf.nn.relu),#Layer 4: [b, 64] => [b, 32]layers.Dense(32, activation=tf.nn.relu),#Layer 5: [b, 32] => [b, 10], 输出类别结果layers.Dense(10)
])#编译网络
model.build(input_shape=[None, 28*28])
model.summary()#进行训练
total_epoches = 5
learn_rate = 0.01#Metrics统计
#参考资料:https://zhuanlan.zhihu.com/p/42438077
#1. 新建meter
#acc_meter = metrics.Accuracy()
#loss_meter = metrics.Mean()
#2. 更新状态, update_state()
#loss_meter.update_state(loss)
#acc_meter.update_state(y, pred)
#3.获取结果, result()
#print(step, 'loss:', loss_meter.result().numpy())
#print(step, 'Evaluate Acc:', total_correct/total, acc_meter.result().numpy())
#4.清除度量信息,reset_states()
#loss_meter.reset_states()
#acc_meter.reset_states()#新建准确度和loss度量对象
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()optimizer = optimizers.Adam(learning_rate = learn_rate)
for epoch in range(total_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:logits = model(x)y_onehot = tf.one_hot(y, depth=10)#使用交叉熵作为lossloss_ce = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))#调用update_state更新loss度量信息loss_meter.update_state(loss_ce)#计算梯度grads = tape.gradient(loss_ce, model.trainable_variables)#更新梯度optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print("Epoch[", epoch, "]: step-", step, "\tloss: ", loss_meter.result().numpy())loss_meter.reset_states()#使用测试集进行验证total_correct = 0total_num = 0#清除准确度的统计信息acc_meter.reset_states()for x,y in test_db:logits = model(x)#使用softmax得到各个类别的概率prob = tf.nn.softmax(logits, axis=1)#求出概率最大的结果参数位置,作为预测的分类结果pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)#比较结果correct = tf.equal(pred, y)correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))#计算精度total_correct += int(correct)total_num += x.shape[0]#使用metircs的update_state进行更新acc_meter.update_state(y, pred)acc = total_correct / total_numprint("Epoch[", epoch, "] Manual Accuracy:", acc, " Metrics Accuracy:", acc_meter.result().numpy())

运行结果:

相关文章:

Tensorflow2.0笔记 - metrics做损失和准确度信息度量

本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下: Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博…...

LeetCode 面试经典150题 290.单词规律

题目: 给定一种规律 pattern 和一个字符串 s ,判断 s 是否遵循相同的规律。 这里的 遵循 指完全匹配,例如, pattern 里的每个字母和字符串 s 中的每个非空单词之间存在着双向连接的对应规律。 思路:一一映射需要用到…...

【CASS精品教程】CASS中计算四参数和七参数(以RTK数据为例)

文章目录 一、四参数介绍二、四参数计算三、七参数介绍四、四参数、七参数的区别一、四参数介绍 两个不同的二维平面直角坐标系之间转换通常使用四参数模型,四参数适合小范围测区的空间坐标转换,相对于七参数转换的优势在于只需要2个公共已知点就能进行转换,操作简单。 在…...

什么是RISC-V?开源 ISA 如何重塑未来的处理器设计

RISC-V代表了处理器架构的范式转变,特点是其开源模型简化了设计理念并促进了全球community-driven的开发。RISC-V导致了处理器技术发展前进方式的重大转变,提供了一个不受传统复杂性阻碍的全新视角。 RISC-V起源于加州大学伯克利分校的学术起点&#xff…...

展馆设计中展示有哪些要求

1、展示产品特点和功能 产品展示应突出产品的特点、功能和优势。通过使用适当的展示方式和展示环境,使观众能够直观地了解产品的外观、结构、性能等方面。可以使用实物展示、模型、原型、图表、动画等方式,以多种角度和视角展示产品的特点和功能。 2、提…...

python实战之PyQt5桌面软件

一. 演示效果 二. 准备工作 1. 使用pip 下载所需包 pyqt5 2. 下载可视化UI工具 QT Designer 链接:https://pan.baidu.com/s/1ic4S3ocEF90Y4L1GqYHPPA?pwdywct 提取码:ywct 3. 可视化UI工具汉化 把上面的链接打开, 里面有安装和汉化包, 前面的路径还要看…...

Switch 和 PS1 模拟器:3000+ 游戏随心玩 | 开源日报 No.174

Ryujinx/Ryujinx Stars: 26.1k License: MIT Ryujinx 是用 C# 编写的实验性任天堂 Switch 模拟器。 该项目旨在提供出色的准确性和性能、用户友好的界面以及稳定的构建。它已经通过了大约 4050 个测试,其中超过 4000 个可以启动并进入游戏,其中大约 340…...

免费翻译pdf格式论文

进入谷歌翻译网址https://translate.google.com/?slauto&tlzh-CN&opdocs 将需要全文翻译的pdf放进去 选择英文到中文,然后点击翻译 可以选择打开译文或者下载译文,下载译文会下载到电脑上,打开译文会在浏览器打开。...

3D产品可视化SaaS

“我们正在走向衰退吗?” “我们已经陷入衰退了吗?” “我们正在步入衰退。” 过去几个月占据头条的问题和陈述引发了关于市场对每个行业影响的讨论和激烈辩论。 特别是对于科技行业来说,过去几周一直很动荡,围绕费用、增长和裁…...

浙大版《C语言程序设计(第4版)》题目集-习题3-5 三角形判断

给定平面上任意三个点的坐标(x1,y1)、(x2,y2)、(x3,y3),检验它们能否构成三角形。 输入格式: 输入在一行中顺序给出六个[−100,100]范围内的数字,即三个点的坐标x1、y1、x2、y2、x3、y3。 输出格式: 若这3个点不能构成三角形,则在一行中输…...

Java封装、继承、多态和抽象深度解析

在软件工程的世界里,面向对象编程(OOP)是一种编程范式,它使用“对象”来设计软件。对象可以封装数据和方法,以提高代码的复用性、可维护性和可扩展性。Java作为一门面向对象的编程语言,提供了四个基本的面向…...

深度学习每周学习总结P3(天气识别)

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制 数据链接 提取码:o3ix 目录 0. 总结1. 数据导入部分数据导入部分代码详解:a. 数据读取部分a.1 提问:关…...

通过iOS网络抓包工具实现移动应用数据安全监控

摘要 本文将深入探讨iOS平台上常用的网络抓包工具,包括Charles、克魔助手、Thor和Http Catcher,以及通过SSH连接进行抓包的方法。此外,还介绍了克魔开发助手作为iOS应用开发的辅助工具,提供的全方面性能监控和调试功能。 在iOS应…...

Stable Diffusion WebUI 生成参数:脚本(Script)——提示词矩阵、从文本框或文件载入提示词、X/Y/Z图表

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里,订阅后可阅读专栏内所有文章。 大家好,我是水滴~~ 在本篇文章中,我们将深入探讨 Stable Diffusion WebUI 的另一个引人注目的生成参数——脚本(Script)。我们将逐一细说提示词矩阵、从文本框或文件导入提示词,…...

synchronized和volatile的原理及应用

文章目录 synchronized的实现原理及应用升级锁代码示例volatile原理及应用代码示例线程不安全类 synchronized的实现原理及应用 synchronized 是Java中用于实现线程同步的关键字,可以应用于方法或代码块,确保在多线程环境下对共享资源的安全访问。下面是…...

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之九 简单闪烁效果

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之九 简单闪烁效果 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之九 简单闪烁效果 一、简单介绍 二、简单闪烁效果实现原理 三、简单闪烁效果案例实现简单步骤 四、注意事项 一、简单…...

11 开源鸿蒙OpenHarmony轻量系统源码分析

开源鸿蒙轻量系统源码分析 作者将狼才鲸日期2024-03-28 一、前言 之前单独的LiteOS是通过Makefile编译的,当前的开源鸿蒙LiteOS-M和LiteOS-A是通过gn和ninja编译的。 Gitee官方只介绍了LiteOS-M的gn ninja编译的流程,针对M3使用Keil编译的流程可能要参…...

专题:一个自制代码生成器(嵌入式脚本语言)之应用实例

初级代码游戏的专栏介绍与文章目录-CSDN博客 我的github:codetoys,所有代码都将会位于ctfc库中。已经放入库中我会指出在库中的位置。 这些代码大部分以Linux为目标但部分代码是纯C的,可以在任何平台上使用。 专题:一个自制代码…...

Appium设备交互API

设备交互API指的是操作设备系统中的一些固有功能,而非被测程序的功能,例如模拟来电,模拟发送短信,设置网络,切换横竖屏,APP操作,打开通知栏,录屏等。 模拟来电 make_gsm_call(phon…...

Qlib-Server部署

Qlib-Server部署 介绍 构建Qlib服务器,用户可以选择: 一键部署Qlib服务器逐步部署Qlib服务器一键部署 Qlib服务器支持一键部署,用户可以选择以下两种方法之一进行一键部署: 使用docker-compose部署在Azure中部署使用docker-compose进行一键部署 按照以下步骤使用docker…...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘

美国西海岸的夏天,再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至,这不仅是开发者的盛宴,更是全球数亿苹果用户翘首以盼的科技春晚。今年,苹果依旧为我们带来了全家桶式的系统更新,包括 iOS 26、iPadOS 26…...

Cesium1.95中高性能加载1500个点

一、基本方式&#xff1a; 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...

相机从app启动流程

一、流程框架图 二、具体流程分析 1、得到cameralist和对应的静态信息 目录如下: 重点代码分析: 启动相机前,先要通过getCameraIdList获取camera的个数以及id,然后可以通过getCameraCharacteristics获取对应id camera的capabilities(静态信息)进行一些openCamera前的…...

管理学院权限管理系统开发总结

文章目录 &#x1f393; 管理学院权限管理系统开发总结 - 现代化Web应用实践之路&#x1f4dd; 项目概述&#x1f3d7;️ 技术架构设计后端技术栈前端技术栈 &#x1f4a1; 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 &#x1f5c4;️ 数据库设…...

视频行为标注工具BehaviLabel(源码+使用介绍+Windows.Exe版本)

前言&#xff1a; 最近在做行为检测相关的模型&#xff0c;用的是时空图卷积网络&#xff08;STGCN&#xff09;&#xff0c;但原有kinetic-400数据集数据质量较低&#xff0c;需要进行细粒度的标注&#xff0c;同时粗略搜了下已有开源工具基本都集中于图像分割这块&#xff0c…...

JVM 内存结构 详解

内存结构 运行时数据区&#xff1a; Java虚拟机在运行Java程序过程中管理的内存区域。 程序计数器&#xff1a; ​ 线程私有&#xff0c;程序控制流的指示器&#xff0c;分支、循环、跳转、异常处理、线程恢复等基础功能都依赖这个计数器完成。 ​ 每个线程都有一个程序计数…...

Java编程之桥接模式

定义 桥接模式&#xff08;Bridge Pattern&#xff09;属于结构型设计模式&#xff0c;它的核心意图是将抽象部分与实现部分分离&#xff0c;使它们可以独立地变化。这种模式通过组合关系来替代继承关系&#xff0c;从而降低了抽象和实现这两个可变维度之间的耦合度。 用例子…...

tomcat入门

1 tomcat 是什么 apache开发的web服务器可以为java web程序提供运行环境tomcat是一款高效&#xff0c;稳定&#xff0c;易于使用的web服务器tomcathttp服务器Servlet服务器 2 tomcat 目录介绍 -bin #存放tomcat的脚本 -conf #存放tomcat的配置文件 ---catalina.policy #to…...

在 Spring Boot 项目里,MYSQL中json类型字段使用

前言&#xff1a; 因为程序特殊需求导致&#xff0c;需要mysql数据库存储json类型数据&#xff0c;因此记录一下使用流程 1.java实体中新增字段 private List<User> users 2.增加mybatis-plus注解 TableField(typeHandler FastjsonTypeHandler.class) private Lis…...

Kafka主题运维全指南:从基础配置到故障处理

#作者&#xff1a;张桐瑞 文章目录 主题日常管理1. 修改主题分区。2. 修改主题级别参数。3. 变更副本数。4. 修改主题限速。5.主题分区迁移。6. 常见主题错误处理常见错误1&#xff1a;主题删除失败。常见错误2&#xff1a;__consumer_offsets占用太多的磁盘。 主题日常管理 …...