Tensorflow2.0:CNN、ResNet实现MNIST分类识别
以下仅是个人的学习笔记 ,内容可能是错误
CNN:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers# 导入数据
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()# 数据预处理
x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1) / 255.0# 构建模型
model = keras.Sequential([layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D(pool_size=(2, 2)),layers.Flatten(),layers.Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)
ResNet18:
import tensorflow as tf
from keras import layers, models, datasets
import os# 定义gpu
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 指定GPU编号
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:try:tf.config.experimental.set_memory_growth(gpus[0], True) # 动态申请显存except RuntimeError as e:print(e)# 加载数据集
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()# 数据预处理
train_images, test_images = train_images / 255.0, test_images / 255.0# 搭建残差模块
def resnet_block(inputs, num_filters=16, kernel_size=3, strides=1, activation='relu'):x = layers.Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(inputs)x = layers.BatchNormalization()(x)if activation:x = layers.Activation(activation)(x)return x# 定义resnet
def resnet18():inputs = layers.Input(shape=(32, 32, 3))num_filters = 64t = layers.BatchNormalization()(inputs)t = resnet_block(t, num_filters=num_filters)for i in range(2):t = resnet_block(t, num_filters=num_filters, activation=None)t = layers.Add()([t, layers.Activation('relu')(t)])t = resnet_block(t, num_filters=num_filters * 2, strides=2, activation=None)t = layers.Add()([t, resnet_block(t, num_filters=num_filters * 2)])num_filters *= 2for i in range(2):t = resnet_block(t, num_filters=num_filters, activation=None)t = layers.Add()([t, layers.Activation('relu')(t)])t = resnet_block(t, num_filters=num_filters * 2, strides=2, activation=None)t = layers.Add()([t, resnet_block(t, num_filters=num_filters * 2)])num_filters *= 2for i in range(2):t = resnet_block(t, num_filters=num_filters, activation=None)t = layers.Add()([t, layers.Activation('relu')(t)])t = layers.AveragePooling2D()(t)outputs = layers.Dense(10, activation='softmax')(layers.Flatten()(t))model = models.Model(inputs, outputs)return model# 定义模型
model = resnet18()
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练 CPU
# history = model.fit(train_images, train_labels, epochs=10,
# validation_data=(test_images, test_labels))with tf.device('GPU:0'): # 指定使用GPUhistory = model.fit(train_images, train_labels, epochs=10,validation_data=(test_images, test_labels))
相关文章:
Tensorflow2.0:CNN、ResNet实现MNIST分类识别
以下仅是个人的学习笔记 ,内容可能是错误 CNN: import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers# 导入数据 (x_train, y_train), (x_test, y_test) keras.datasets.mnist.load_data()# 数据预处理 x_tra…...
本地jar导入maven
一、通过dependency引入 1.1. jar包放置,建造lib目录 1.2. pom.xml文件 <dependency><groupId>zip4j</groupId><artifactId>zip4j</artifactId><version>1.3.2</version><!--system,类似provided&#x…...
数据结构与算法【堆】的Java实现
前言 之前已经说过堆的特点了,具体文章在数据结构与算法【队列】的Java实现-CSDN博客。因此直接实现堆的其他功能。 建堆 所谓建堆,就是将一个初始的堆变为大顶堆或是小顶堆。这里以大顶堆为例。展示如何建堆。 找到最后一个非叶子节点从后向前&…...
同创永益联合红帽打造一站式数字韧性解决方案
随着AI技术的快速兴起,IT技术已成为推动业务持续增长的重要驱动力,这要求企业不断尝试新事物,改变固有流程,加强IT技术与业务的合作,同时提升数字韧性能力,以实现业务目标。10月26日,红帽2023中…...
c++ call_once 使用详解
c call_once 使用详解 std::call_once 头文件 #include <mutex>。 函数原型: template<class Callable, class... Args> void call_once(std::once_flag& flag, Callable&& f, Args&&... args);flag:标志对象…...
【rosrun diagnostic_analysis】报错No module named rospkg | ubuntu 20.04
ubuntu20.04使用指令报错 现象 rosrun diagnostic_analysis export_csv.py my.bag -d ~/Desktop报错 Traceback (most recent call last): File "/opt/ros/noetic/lib/diagnostic_analysis/export_csv.py", line 40, in <module> import roslib; roslib.load_m…...
高防CDN有什么作用?
分布式拒绝服务攻击(DDoS攻击)是一种针对目标系统的恶意网络攻击行为,DDoS攻击经常会导致被攻击者的业务无法正常访问,也就是所谓的拒绝服务。 常见的DDoS攻击包括以下几类: 网络层攻击:比较典型的攻击类…...
盛元广通开放实训室管理系统2.0
开放实训室管理系统是一种基于网络和数据库的实训室信息管理系统,旨在提高实训室的管理水平,实现实训资源的优化配置和高效利用。该系统通常包括用户管理、设备管理、课程管理、考核管理等功能模块,能够实现实训室的预约、设备借用、课程安排…...
3D建模基础教程:编辑多边形功能命令快捷方式
一、打开3D软件并创建新模型 首先,打开你的3D建模软件,比如Blender、Maya或3ds Max。然后,创建一个新的3D模型。你可以使用基本几何体来创建模型,也可以导入现有的模型。 二、进入编辑多边形模式 在主工具栏中,找到并…...
SaleSmartly新增AI意图识别触发器!让客户享受更精准的自动化服务
AI意图识别技术是对话式AI中很重要的组成部分,通俗点来说就是一种可以识别用户在对话中表达的意图的技术。通过对大量数据的分析和学习,AI可以理解用户想要获得的信息,并根据这些信息来采取相应的行动或提供相应的响应。而在对话式AI中&#…...
计算机毕业设计选题推荐-个人博客微信小程序/安卓APP-项目实战
✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…...
一篇详解,Postman设置token依赖步骤
前言 postman做接口测试时,大多数的接口必须在有token的情况下才能运行,我们可以获取token后设置一个环境变量供所在同一个集合中的所有接口使用。 一般是通过调用登录接口,获取到token的值 实战项目:jeecg boot项目 项目官网…...
音频录制实现 绘制频谱
思路 获取设备信息 获取录音的频谱数据 绘制频谱图 具体实现 封装 loadDevices.js /*** 是否支持录音*/ const recordingSupport () > {const scope navigator.mediaDevices || {};if (!scope.getUserMedia) {scope navigatorscope.getUserMedia || (scope.getUserM…...
nginx代理本地服务请求,避免跨域;前端图片压缩并上传
痛点 有时用vscode进行一些测试 请求不同端口服务、或者其他服务接口时时,老是会报跨域,非常的烦 所有就想用 nginx 进行请求代理,来解决这个痛点 nginx 下载地址:nginx: download 下载到某一目录: window下nginx相关…...
Vue3-readonly(深只读) 与 shallowReadonly(浅只读)
Vue3-readonly(深只读) 与 shallowReadonly(浅只读) readonly(深只读):具有响应式对象中所有的属性,其所有值都是只读且不可修改的。shallowReadonly(浅只读):具有响应式对象的第一层属性值是只读且不可修改的&#x…...
中小企业怎么实现数字化转型?有什么实用的工单管理系统?
当前,世界经济数字化转型已是大势所趋。在这个数字化转型的大潮中,如果企业仍然逆水而行,不随大流,那么,企业将有可能会被抛弃,被对手超越,甚至被市场边缘化,导致最终的结果是&#…...
vue3.x中父组件添加自定义参数后,如何获取子组件$emit传递过来的参数
之前写过一篇文章,vue中父组件添加自定义参数后,如何获取子组件$emit传递过来的参数 现在已经进入vue3.x开发的时代了,那么vue3.x中父组件添加自定义参数后,如何获取子组件$emit传递过来的参数? 1、子组件使用emit传…...
【Machine Learning in R - Next Generation • mlr3】
本篇主要介绍mlr3包的基本使用。 一个简单的机器学习流程在mlr3中可被分解为以下几个部分: 创建任务 比如回归、分裂、生存分析、降维、密度任务等等挑选学习器(算法/模型) 比如随机森林、决策树、SVM、KNN等等训练和预测 创建任务 本次示…...
CorelDraw2024(CDR)- 矢量图制作软件介绍
在当今数字化时代,平面设计已成为营销、品牌推广和创意表达中不可或缺的元素。平面设计必备三大软件Adebo PhotoShop、CorelDraw、Adobe illustrator, 今天小编就详细介绍其中之一的CorelDraw软件。为什么这款软件在设计界赢得了声誉,并成为了设计师的无…...
RT-DETR优化改进:轻量级Backbone改进 | VanillaNet极简神经网络模型 | 华为诺亚2023
🚀🚀🚀本文改进:一种极简的神经网络模型 VanillaNet,支持vanillanet_5, vanillanet_6, vanillanet_7, vanillanet_8, vanillanet_9, vanillanet_10, vanillanet_11等版本,相对于自带的rtdetr-l、rtdetr-x参数量如下: layersparametersgradientsvanillanet_5338277174…...
熬过漫漫长夜,终见微光入怀
民宿刘姐我扎根浙东深山,经营一方山间小院,至今已是六个春秋。回望这六七年来的创业之路,那些彻夜难眠的深夜、压垮身心的重担、前路迷茫的无助与煎熬,依旧刻骨铭心,仿佛一切就发生在昨日。最初怀揣对山野生活的赤诚与…...
Insomnia终极指南:构建高效API测试与协作的完整工作流
Insomnia终极指南:构建高效API测试与协作的完整工作流 【免费下载链接】insomnia The open-source, cross-platform API client for GraphQL, REST, WebSockets, SSE and gRPC. With Cloud, Local and Git storage. 项目地址: https://gitcode.com/gh_mirrors/in/…...
学习笔记·敏捷开发
“嗨,阿米戈!” “嗨,比拉博!” “今天我要给大家讲讲程序通常是怎么开发的。” “在 20 世纪,当现代 IT 还处于起步阶段时,每个人似乎都认为编程就像建筑或制造。” “事情通常是这样的:” “客户会解释他需要的程序类型——它应该做什么以及应该如何做。” “业…...
代数拓扑运算流程
文章目录0、背景一、标准计算流程:以单纯同调为例空间剖分,构建单纯复形生成各维度链群定义边界算子定义闭链群与边缘链群计算同调群并解读拓扑信息推导最终拓扑结论二、其他核心概念的典型计算逻辑0、背景 之前为了做一个东西学习TDA&…...
2026 河北 GEO 优化服务商测评:理性看实力,盘古开物AI智推适配才是硬道理
覆盖石家庄、唐山、保定、邯郸、邢台,立足华北,辐射全国,不搞噱头,只讲真实能力随着生成式 AI 全面融入商业营销,GEO 优化已经从河北企业的可选服务,变成抢占区域流量、提升线上可见度的重要方式。尤其制造…...
为内部ai工具平台选择统一api网关时taotoken的接入与管理价值
🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 为内部AI工具平台选择统一API网关时Taotoken的接入与管理价值 当公司内部需要构建一个集成多种AI能力的工具平台时,技术…...
Unity热更新本质与分层设计原理
1. 热更新不是“打补丁”,而是游戏生命周期的呼吸系统很多人第一次听说“Unity热更新”,脑子里立刻蹦出一个画面:玩家正在打Boss,突然弹出“检测到新版本,正在后台下载……3秒后重启生效”。然后下意识觉得——这不就是…...
2026年转型风口:理发店转战植物染发,能占据市场前10%吗?
2026年,理发店转型的风口已经悄然来临。据数据显示,植物染发和养护市场增速保持在15%以上,而白发脱发人群的比例不断增大,这无疑给众多理发店提供了巨大的转型机会。本文将通过具体的数据、案例和观点,探讨理发店转型植…...
Agent-S3技术深度解析:首个超越人类性能的GUI智能体架构演进与应用实践
Agent-S3技术深度解析:首个超越人类性能的GUI智能体架构演进与应用实践 【免费下载链接】Agent-S Agent S: an open agentic framework that uses computers like a human 项目地址: https://gitcode.com/GitHub_Trending/ag/Agent-S Agent-S3作为首个在OSWo…...
洛雪音乐音源配置完全指南:免费搭建个人音乐库的终极方案
洛雪音乐音源配置完全指南:免费搭建个人音乐库的终极方案 【免费下载链接】lxmusic- lxmusic(洛雪音乐)全网最新最全音源 项目地址: https://gitcode.com/gh_mirrors/lx/lxmusic- 洛雪音乐作为一款强大的音乐播放工具,提供了全网最新最全的音源资…...
