当前位置: 首页 > news >正文

04 卷积神经网络搭建

一、数据集

MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的[参考]。

MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。
 

 1.1 准备数据

将数据集分为训练集、验证集和测试集 

#训练集有60000张图片,前5000张图片作为验证集,后55000作为训练集

1. `x_train_all` 和 `y_train_all`:
   - `x_train_all` 包含了完整的训练数据集的图像数据,这些图像用于训练深度学习模型。
   - `y_train_all` 包含了完整的训练数据集的标签,即与 `x_train_all` 中的图像相对应的类别标签。

2. `x_test` 和 `y_test`:
   - `x_test` 包含了测试数据集的图像数据,这些图像用于评估深度学习模型的性能。
   - `y_test` 包含了测试数据集的标签,即与 `x_test` 中的图像相对应的类别标签。

from tensorflow import keras
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as pltfashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

1.2 标准化

# 标准化
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28, 1)
x_valid_scaled = scaler.transform(x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28, 1)x_test_scaled = scaler.transform(x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28, 1)

1.3 数据集

def make_dataset(data, target, epochs, batch_size, shuffle=True):dataset = tf.data.Dataset.from_tensor_slices((data, target))if shuffle:dataset = dataset.shuffle(10000)dataset = dataset.repeat(epochs).batch(batch_size).prefetch(50)return datasetbatch_size = 64
epochs = 20
train_dataset = make_dataset(x_train_scaled, y_train, epochs, batch_size)

1.4 搭建模型

model = keras.models.Sequential()
# 卷积
model.add(keras.layers.Conv2D(filters = 32, kernel_size = 3, padding = 'same',activation='relu',# batch_size, height, width, channelsinput_shape=(28, 28, 1))) # (28, 28, 32)model.add(keras.layers.Conv2D(filters = 32, kernel_size = 3, padding = 'same',activation='relu'))
# 池化
model.add(keras.layers.MaxPool2D()) # (14, 14, 32)model.add(keras.layers.Conv2D(filters = 64, kernel_size = 3, padding = 'same',activation='relu')) # (14, 14, 64)
model.add(keras.layers.Conv2D(filters = 64, kernel_size = 3, padding = 'same',activation='relu'))
# 池化
model.add(keras.layers.MaxPool2D()) # (7, 7, 64)
model.add(keras.layers.Conv2D(filters = 128, kernel_size = 3, padding = 'same',activation='relu')) # (7, 7, 128)
model.add(keras.layers.Conv2D(filters = 128, kernel_size = 3, padding = 'same',activation='relu')) # (7, 7, 128)
# 池化, 向下取整
model.add(keras.layers.MaxPooling2D()) # (3, 3, 128)model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(512, activation='relu'))
model.add(keras.layers.Dense(256, activation='relu'))
model.add(keras.layers.Dense(10, activation='softmax'))model.compile(loss='sparse_categorical_crossentropy',optimizer='adam',metrics=['accuracy'])print(model.summary())

keras.models.Sequential()`是使用 Keras 构建神经网络模型的开始。`keras.models.Sequential()` 创建了一个 Sequential 模型对象,这是 Keras 中的一种常见模型类型。Sequential 模型是一个线性的、层叠的神经网络模型,适用于顺序层的堆叠,其中每一层都是依次添加到模型中

一旦创建了 Sequential 模型,你可以使用 `.add()` 方法来逐层添加神经网络层,从输入层到输出层。每个层都可以通过实例化 Keras 中的层类来创建,例如

`keras.layers.Dense` 用于全连接层,

`keras.layers.Conv2D` 用于卷积层等。

以下是一个简单的例子,演示如何使用 Sequential 模型创建一个简单的前馈神经网络

from tensorflow import keras# 创建一个 Sequential 模型
model = keras.models.Sequential()# 添加输入层和第一个隐藏层
model.add(keras.layers.Input(shape=(input_shape,)))  # 输入层,input_shape 根据你的数据维度定义
model.add(keras.layers.Dense(units=128, activation='relu'))  # 隐藏层1# 添加第二个隐藏层
model.add(keras.layers.Dense(units=64, activation='relu'))  # 隐藏层2# 添加输出层
model.add(keras.layers.Dense(units=num_classes, activation='softmax'))  # 输出层,num_classes 是输出类别的数量# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

上面的代码创建了一个具有两个隐藏层和一个输出层的前馈神经网络模型,并使用了 ReLU 激活函数和 softmax 激活函数。这只是一个简单的示例,你可以根据你的任务和数据来构建更复杂的模型。一旦模型构建完成,你可以使用 `.fit()` 方法来训练模型。

model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

是 Keras 中用于编译深度学习模型的代码,它设置了模型的损失函数、优化器和评估指标。让我为你解释每个参数的含义:

- `loss='sparse_categorical_crossentropy'`: 这里设置了模型的损失函数。`sparse_categorical_crossentropy` 是一种用于多类别分类问题的损失函数。它适用于目标变量是整数形式(类别标签)的情况,而不需要将目标变量进行独热编码(one-hot encoding)。模型的目标是最小化这个损失函数,从而使预测结果尽可能接近真实标签。

- `optimizer='adam'`: 这里设置了优化器,用于模型的参数更新。`'adam'` 是一种常用的优化算法,它基于梯度下降的方法,具有自适应学习率和动量的特性,通常在深度学习中表现良好。优化器的作用是最小化损失函数,从而调整模型的权重和参数以使模型更好地拟合数据。

- `metrics=['accuracy']`: 这里设置了评估指标,用于在模型训练期间监测模型性能。

`['accuracy']` 表示模型在训练期间将计算并输出准确度(accuracy),即正确分类的样本数与总样本数的比率。准确度通常用于分类问题的性能评估。

一旦模型编译完成,你可以使用 `model.fit()` 方法来训练模型,该方法会使用上述设置的损失函数、优化器和评估指标来进行训练。例如:

```python
model.fit(x_train, y_train, epochs=10, validation_data=(x_valid, y_valid))
```

1.5 train

eval_dataset = make_dataset(x_valid_scaled, y_valid, epochs=1, batch_size=32, shuffle=False)history = model.fit(train_dataset, steps_per_epoch=x_train_scaled.shape[0] // batch_size,epochs=10,validation_data=eval_dataset)

 

1.6 模型评估

model.evaluate(eval_dataset)

 

1.7 可视化

def plot_learning_curves(history):pd.DataFrame(history.history).plot(figsize=(8, 5))plt.grid(True)plt.gca().set_ylim(0, 1)plt.show()plot_learning_curves(history)

相关文章:

04 卷积神经网络搭建

一、数据集 MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的[参考]。 MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图…...

【hadoop运维】running beyond physical memory limits:正确配置yarn中的mapreduce内存

文章目录 一. 问题描述二. 问题分析与解决1. container内存监控1.1. 虚拟内存判断1.2. 物理内存判断 2. 正确配置mapReduce内存2.1. 配置map和reduce进程的物理内存:2.2. Map 和Reduce 进程的JVM 堆大小 3. 小结 一. 问题描述 在hadoop3.0.3集群上执行hive3.1.2的任…...

数据结构--6.5二叉排序树(插入,查找和删除)

目录 一、创建 二、插入 三、删除 二叉排序树(Binary Sort Tree)又称为二叉查找树,它或者是一棵空树,或者是具有下列性质的二叉树: ——若它的左子树不为空,则左子树上所有结点的值均小于它的根结构的值…...

无需公网IP,在家SSH远程连接公司内网服务器「cpolar内网穿透」

文章目录 1. Linux CentOS安装cpolar2. 创建TCP隧道3. 随机地址公网远程连接4. 固定TCP地址5. 使用固定公网TCP地址SSH远程 本次教程我们来实现如何在外公网环境下,SSH远程连接家里/公司的Linux CentOS服务器,无需公网IP,也不需要设置路由器。…...

Java工具类

一、org.apache.commons.io.IOUtils closeQuietly() toString() copy() toByteArray() write() toInputStream() readLines() copyLarge() lineIterator() readFully() 二、org.apache.commons.io.FileUtils deleteDirectory() readFileToString() de…...

makefile之使用函数wildcard和patsubst

Makefile之调用函数 调用makefile机制实现的一些函数 $(function arguments) : function是函数名,arguments是该函数的参数 参数和函数名用空格或Tab分隔,如果有多个参数,之间用逗号隔开. wildcard函数:让通配符在makefile文件中使用有效果 $(wildcard pattern) 输入只有一个参…...

算法通关村第十八关——排列问题

LeetCode46.给定一个没有重复数字的序列,返回其所有可能的全排列。例如: 输入:[1,2,3] 输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]] 元素1在[1,2]中已经使…...

基于STM32设计的生理监测装置

一、项目功能要求 设计并制作一个生理监测装置,能够实时监测人体的心电图、呼吸和温度,并在LCD液晶显示屏上显示相关数据。 随着现代生活节奏的加快和环境的变化,人们对身体健康的关注程度越来越高。为了及时掌握自身的生理状况&#xff0c…...

Go-Python-Java-C-LeetCode高分解法-第五周合集

前言 本题解Go语言部分基于 LeetCode-Go 其他部分基于本人实践学习 个人题解GitHub连接:LeetCode-Go-Python-Java-C Go-Python-Java-C-LeetCode高分解法-第一周合集 Go-Python-Java-C-LeetCode高分解法-第二周合集 Go-Python-Java-C-LeetCode高分解法-第三周合集 G…...

【前端知识】前端加密算法(base64、md5、sha1、escape/unescape、AES/DES)

前端加密算法 一、base64加解密算法 简介:Base64算法使用64个字符(A-Z、a-z、0-9、、/)来表示二进制数据的64种可能性,将每3个字节的数据编码为4个可打印字符。如果字节数不是3的倍数,将会进行填充。 优点&#xff1…...

leetcode 925. 长按键入

2023.9.7 我的基本思路是两数组字符逐一对比,遇到不同的字符,判断一下typed与上一字符是否相同,不相同返回false,相同则继续对比。 最后要分别判断name和typed分别先遍历完时的情况。直接看代码: class Solution { p…...

[CMake教程] 循环

目录 一、foreach()二、while()三、break() 与 continue() 作为一个编程语言&#xff0c;CMake也少不了循环流程控制&#xff0c;他提供两种循环foreach() 和 while()。 一、foreach() 基本语法&#xff1a; foreach(<loop_var> <items>)<commands> endfo…...

Mojo安装使用初体验

一个声称比python块68000倍的语言 蹭个热度&#xff0c;安装试试 系统配置要求&#xff1a; 不支持Windows系统 配置要求: 系统&#xff1a;Ubuntu 20.04/22.04 LTSCPU&#xff1a;x86-64 CPU (with SSE4.2 or newer)内存&#xff1a;8 GiB memoryPython 3.8 - 3.10g or cla…...

艺术与AI:科技与艺术的完美融合

文章目录 艺术创作的新工具生成艺术艺术与数据 AI与互动艺术虚拟现实&#xff08;VR&#xff09;与增强现实&#xff08;AR&#xff09;机器学习与互动性 艺术与AI的伦理问题结语 &#x1f389;欢迎来到AIGC人工智能专栏~艺术与AI&#xff1a;科技与艺术的完美融合 ☆* o(≧▽≦…...

Android常用的工具“小插件”——Widget机制

Widget俗称“小插件”&#xff0c;是Android系统中一个很常用的工具。比如我们可以在Launcher中添加一个音乐播放器的Widget。 在Launcher上可以添加插件&#xff0c;那么是不是说只有Launcher才具备这个功能呢&#xff1f; Android系统并没有具体规定谁才能充当“Widget容器…...

探索在云原生环境中构建的大数据驱动的智能应用程序的成功案例,并分析它们的关键要素。

文章目录 1. Netflix - 个性化推荐引擎2. Uber - 实时数据分析和决策支持3. Airbnb - 价格预测和优化5. Google - 自然语言处理和搜索优化 &#x1f388;个人主页&#xff1a;程序员 小侯 &#x1f390;CSDN新晋作者 &#x1f389;欢迎 &#x1f44d;点赞✍评论⭐收藏 ✨收录专…...

jupyter 添加中文选项

文章目录 jupyter 添加中文选项1. 下载中文包2. 选择中文重新加载一下&#xff0c;页面就变成中文了 jupyter 添加中文选项 1. 下载中文包 pip install jupyterlab-language-pack-zh-CN2. 选择中文 重新加载一下&#xff0c;页面就变成中文了 这才是设置中文的正解&#xff…...

系列十、Java操作RocketMQ之批量消息

一、概述 RocketMQ可以一次性发送一组消息&#xff0c;那么这一组消息会被当做一个消息进行消费。 二、案例代码 2.1、pom 同系列五 2.2、RocketMQConstant 同系列五 2.3、BatchConsumer package org.star.batch.consumer;import cn.hutool.core.util.StrUtil; import lom…...

leetcode1两数之和

题目&#xff1a; 给定一个整数数组 nums 和一个整数目标值 target&#xff0c;请你在该数组中找出 和为目标值 target 的那 两个 整数&#xff0c;并返回它们的数组下标。 你可以假设每种输入只会对应一个答案。但是&#xff0c;数组中同一个元素在答案里不能重复出现。 你…...

近年GDC服务器分享合集(四): 《火箭联盟》:为免费游玩而进行的扩展

如今&#xff0c;网络游戏采用免费游玩&#xff08;Free to Play&#xff09;加内购的比例要远大于买断制&#xff0c;这是因为前者能带来更低的用户门槛。甚至有游戏为了获取更多的用户&#xff0c;选择把原来的买断制改为免费游玩&#xff0c;一个典型的例子就是最近的网易的…...

《Playwright:微软的自动化测试工具详解》

Playwright 简介:声明内容来自网络&#xff0c;将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具&#xff0c;支持 Chrome、Firefox、Safari 等主流浏览器&#xff0c;提供多语言 API&#xff08;Python、JavaScript、Java、.NET&#xff09;。它的特点包括&a…...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八

现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet&#xff0c;点击确认后如下提示 最终上报fail 解决方法 内核升级导致&#xff0c;需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...

BCS 2025|百度副总裁陈洋:智能体在安全领域的应用实践

6月5日&#xff0c;2025全球数字经济大会数字安全主论坛暨北京网络安全大会在国家会议中心隆重开幕。百度副总裁陈洋受邀出席&#xff0c;并作《智能体在安全领域的应用实践》主题演讲&#xff0c;分享了在智能体在安全领域的突破性实践。他指出&#xff0c;百度通过将安全能力…...

大数据学习(132)-HIve数据分析

​​​​&#x1f34b;&#x1f34b;大数据学习&#x1f34b;&#x1f34b; &#x1f525;系列专栏&#xff1a; &#x1f451;哲学语录: 用力所能及&#xff0c;改变世界。 &#x1f496;如果觉得博主的文章还不错的话&#xff0c;请点赞&#x1f44d;收藏⭐️留言&#x1f4…...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中&#xff0c;新增了一个本地验证码接口 /code&#xff0c;使用函数式路由&#xff08;RouterFunction&#xff09;和 Hutool 的 Circle…...

【Go语言基础【12】】指针:声明、取地址、解引用

文章目录 零、概述&#xff1a;指针 vs. 引用&#xff08;类比其他语言&#xff09;一、指针基础概念二、指针声明与初始化三、指针操作符1. &&#xff1a;取地址&#xff08;拿到内存地址&#xff09;2. *&#xff1a;解引用&#xff08;拿到值&#xff09; 四、空指针&am…...

推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材)

推荐 github 项目:GeminiImageApp(图片生成方向&#xff0c;可以做一定的素材) 这个项目能干嘛? 使用 gemini 2.0 的 api 和 google 其他的 api 来做衍生处理 简化和优化了文生图和图生图的行为(我的最主要) 并且有一些目标检测和切割(我用不到) 视频和 imagefx 因为没 a…...

【LeetCode】3309. 连接二进制表示可形成的最大数值(递归|回溯|位运算)

LeetCode 3309. 连接二进制表示可形成的最大数值&#xff08;中等&#xff09; 题目描述解题思路Java代码 题目描述 题目链接&#xff1a;LeetCode 3309. 连接二进制表示可形成的最大数值&#xff08;中等&#xff09; 给你一个长度为 3 的整数数组 nums。 现以某种顺序 连接…...

django blank 与 null的区别

1.blank blank控制表单验证时是否允许字段为空 2.null null控制数据库层面是否为空 但是&#xff0c;要注意以下几点&#xff1a; Django的表单验证与null无关&#xff1a;null参数控制的是数据库层面字段是否可以为NULL&#xff0c;而blank参数控制的是Django表单验证时字…...

关于easyexcel动态下拉选问题处理

前些日子突然碰到一个问题&#xff0c;说是客户的导入文件模版想支持部分导入内容的下拉选&#xff0c;于是我就找了easyexcel官网寻找解决方案&#xff0c;并没有找到合适的方案&#xff0c;没办法只能自己动手并分享出来&#xff0c;针对Java生成Excel下拉菜单时因选项过多导…...