深度学习笔记(七)——基于Iris/MNIST数据集构建基础的分类网络算法实战
文中程序以Tensorflow-2.6.0为例
部分概念包含笔者个人理解,如有遗漏或错误,欢迎评论或私信指正。
截图和程序部分引用自北京大学机器学习公开课
认识网络的构建结构
在神经网络的构建过程中,都避不开以下几个步骤:
- 导入网络和依赖模块
- 原始数据处理和清洗
- 加载训练和测试数据
- 构建网络结构,确定网络优化方法
- 将数据送入网络进行训练,同时判断预测效果
- 保存模型
- 部署算法,使用新的数据进行预测推理
使用Keras快速构建网络的必要API
在tensorflow2版本中将很多基础函数进行了二次封装,进一步急速了算法初期的构建实现。通过keras提供的很多高级API可以在较短的代码体量上实现网络功能。同时通过搭配tf中的基础功能函数可以实现各种不同类型的卷积和组合操作。正是这中高级API和底层元素及的操作大幅度的提升了tensorflow的自由程度和易用性。
常用网络
全连接层
tf.keras.layers.Dense(units=3, activation=tf.keras.activations.softmax, kernel_regularizer=tf.keras.regularizers.L2())
units:维数(神经元个数)
activation:激活函数,可选:relu softmax sigmoid tanh,这里记不住的话可以用tf.keras.activations.逐个查看
kernel_regularizer:正则化函数,同样的可以使用tf.keras.regularizers.逐个查看
全连接层是标准的神经元组成,更多被用在网络的后端或解码端(Decoder)用来输出预测数据。
拉伸层(维度展平)
tf.keras.layers.Flatten()
这个函数默认不需要输入参数,直接使用,它会将多维的数据按照每一行依次排开首尾连接变成一个一维的张量。通常在数据输入到全连接层之前使用。
卷积层
tf.keras.layers.Conv2D(filters=3, kernel_size=3, strides=1, padding='valid')
filters:卷积核个数
kernel_size:卷积核尺寸
strides:卷积核步长,卷积核是在原始数据上滑动遍历完成数据计算。
padding:可填 ‘valid’ ‘same’,是否使用全零填充,影响最后卷积结果的大小。
卷积一般被用来提取数据的数据特征。卷积最关键的就是卷积核个数和卷积核尺寸。假设输入一个1nn大小的张量,经过x个卷积核+步长为2+尺寸可以整除n的卷积层之后会输出一个x*(n/2)*(n/2)大小的张量。可以理解为卷积步长和卷积核大小影响输出张量的长宽,卷积核的大小影响输出张量的深度。
构建网络
使用Sequential构建简单网络,或者构建网络模块。列表中顺序包含网络的各个层。
tf.keras.models.Sequential([ ])
使用独立的class构建,这里定义一个类继承自 tensorflow.keras.Model 后面基本是标准结构>初始化相关参数>定义网络层>重写call函数定义前向传播层的连接顺序。后续随着使用的深入可以进一步的添加更多函数来实现不同类型的网络。
class mynnModel(Model): # 继承from tensorflow.keras import Model 作为父类def __init__(self):super(IrisModel, self).__init__() # 初始化父类的参数self.d1 = layers.Dense(units=3, activation=tf.keras.activations.softmax, kernel_regularizer=tf.keras.regularizers.L2())def call(self, input): # 重写前向传播函数y = self.d1(input)return ymodel = IrisModel()
训练及其参数设置
设置训练参数
tensorflow.keras.Model.compile(optimizer=参数更新优化器,loss=损失函数metrics=准确率计算方式,即输出数据类型和标签数据类型如何对应)
具体参数可以看下面的内容:
optimizer:参数优化器 SGD: tf.keras.optimizers.SGD(learning_rate=0.1,momentum=动量参数) learning_rate学习率,momentum动量参数AdaGrad: tf.keras.optimizers.Adagrad(learning_rate=学习率)Adam: tf.keras.optimizers.Adam(learning_rate=学习率 , beta_1=0.9, beta_2=0.999)
loss:损失函数MSE: tf.keras.losses.MeanSquaredError()交叉熵损失: tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False) from_logits=true时输出值经过一次softmax概率归一化
metrics:准确率计算方式,就是输出数据类型和标签数据类型如何对应数值型(两个都是序列值): 'accuracy'都是独热码: 'categorical_accuracy'标签是数值,输出是独热码: 'sparse_categorical_accuracy'
训练
tensorflow.keras.Model.model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
网络传入参数含义如下:
输入的数据依次为:输入训练特征数据,标签数据,单次输入数据量,迭代次数
validation_split=从训练集划分多少比例数据用来测试 / validation_data=(测试特征数据,测试标签数据) 这两个参数智能二选一
validation_freq=多少次epoch测试一次
输出网络信息
tensorflow.keras.Model.model.summary()
上面这个函数可以在训练结束或者训练开始之前输出一次网络的结构信息用于确认。
实际应用展示
环境
软件环境的配置可以查看环境配置流程说明
cuda = 11.8 # CUDA也可以使用11.2版本
python=3.7
numpy==1.19.5
matplotlib== 3.5.3
notebook==6.4.12
scikit-learn==1.2.0
tensorflow==2.6.0
keras==2.6.0
使用iris数据集构建基础的分类网络
import tensorflow as tf
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)model = tf.keras.models.Sequential([ tf.keras.layers.Dense(3, activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary( )
通过上面这样几行简单的代码,我们实现了对iris数据的分类训练。在上面的代码中使用了Sequential函数来构建网络。
使用MNIST数据集设计分类网络
在开始下面的代码之前,要先下载对应的数据 https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 复制这段网址在浏览器打开会直接下载数据,然后将下载好的mnist.npz复制到一个新的路径下,然后在tf.keras.datasets.mnist.load_data(path=‘you file path ’)代码中的这行里修改为你的路径,注意要使用绝对路径。
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras import layers
from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path='E:\Tensorflow\data\mnist.npz') # 注意替换自己的使用绝对路径
x_train, x_test = x_train/255.0, x_test/255.0 # 图像数据归一化
print('训练集样本的大小:', x_train.shape)
print('训练集标签的大小:', y_train.shape)
print('测试集样本的大小:', x_test.shape)
print('测试集标签的大小:', y_test.shape)
#可视化样本,下面是输出了训练集中前20个样本
fig, ax = plt.subplots(nrows=4,ncols=5,sharex='all',sharey='all')
ax = ax.flatten()
for i in range(20):img = x_train[i].reshape(28, 28)ax[i].imshow(img,cmap='Greys')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
# 定义网络结构
class mnisModel(Model):def __init__(self, *args, **kwargs):super(mnisModel, self).__init__(*args, **kwargs)self.flatten1=layers.Flatten()self.d1=layers.Dense(128, activation=tf.keras.activations.relu)self.d2=layers.Dense(10, activation=tf.keras.activations.softmax)def call(self, input):x = self.flatten1(input)x = self.d1(x)x = self.d2(x)return(x)
model = mnisModel()
#设置训练参数
model.compile(optimizer='adam', # 'adam' tf.keras.optimizers.Adam(learning_rate=0.4 , beta_1=0.9, beta_2=0.999)loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
# 训练
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data = (x_test, y_test), validation_freq=1)
model.summary()
运行后会先显示数据集中的前二十个数字

关闭数字展示窗口后开始训练,并看到训练的过程

相关文章:
深度学习笔记(七)——基于Iris/MNIST数据集构建基础的分类网络算法实战
文中程序以Tensorflow-2.6.0为例 部分概念包含笔者个人理解,如有遗漏或错误,欢迎评论或私信指正。 截图和程序部分引用自北京大学机器学习公开课 认识网络的构建结构 在神经网络的构建过程中,都避不开以下几个步骤: 导入网络和依…...
Windows启动MongoDB服务报错(错误 1053:服务没有及时响应启动或控制请求)
问题描述:修改MongoDB服务bin目录下的mongod.cfg,然后在任务管理器找到MongoDB服务-->右键-->点击【开始】,启动失败无提示: 右键点击任务管理器的MongoDB服务-->点击【打开服务】,跳转到服务页面-->找到M…...
Android Framework 常见解决方案(25-2)定制CPUSET解决方案-system修改及编译部分调整
1 原理说明 这个方案有如下基本需求: 构建自定义CPUSET,/dev/cpuset中包含一个全新的cpuset分组。且可以通过set_cpuset_policy和set_sched_policy接口可以设置自定义CPUSET。开机启动后可以通过zygote判定来对特定的应用进程设置CPUSET,并…...
OpenAI推出GPT商店和ChatGPT Team服务
🦉 AI新闻 🚀 OpenAI推出GPT商店和ChatGPT Team服务 摘要:OpenAI正式推出了其GPT商店和ChatGPT Team服务。用户已经创建了超过300万个ChatGPT自定义版本,并分享给其他人使用。GPT商店集结了用户为各种任务创建的定制化ChatGPT&a…...
3D建模素材分层渲染怎么操作?
在3D建模素材分层渲染过程中,需要将场景中的元素分到不同的层里,然后分别进行渲染。以下是一个简单的方法: 1、打开要渲染的3D建模素材。 2、在场景中选择要分层的元素,然后在软件的图层面板中新建图层,将元素拖拽到新…...
SAICP(模拟退火迭代最近点)的实现
SAICP(模拟退火迭代最近点)的实现 注: 本系列所有文章在github开源, 也是我个人的学习笔记, 欢迎大家去star以及fork, 感谢! 仓库地址: pointcloud-processing-visualization 总结一下上周的学习情况 ICP会存在局部最小值的问题, 这个问题可能即使是没有实际遇到过, 也或多…...
FineBI实战项目一(23):订单商品分类词云图分析开发
点击新建组件,创建订单商品分类词云图组件。 选择词云,拖拽catName到颜色和文本,拖拽cat到大小。 将组件拖拽到仪表板。 结果如下:...
DOS命令
当使用DOS命令时,可以在命令提示符下输入各种命令以执行不同的任务。以下是一些常见DOS命令的详细说明: dir (Directory): 列出当前目录中的文件和子目录。 用法: dir [drive:][path][filename] [/p] [/w] cd (Change Directory): 更改当前目录。 用法: …...
【Python】torch中的.detach()函数详解和示例
在PyTorch中,.detach()是一个用于张量的方法,主要用于创建该张量的一个“离断”版本。这个方法在很多情况下都非常有用,例如在缓存释放、模型评估和简化计算图等场景中。 .detach()方法用于从计算图中分离一个张量,这意味着它创建…...
二级域名分发系统源码 对接易支付php源码 全开源
全面开源的易支付PHP源码分享:实现二级域名分发对接 首先,在epay的config.php文件中修改您的支付域名。 随后,在二级域名分发网站上做相应修改。 伪静态 location / { try_files $uri $uri/ /index.php?$query_string; } 源码下载&#…...
二分查找与搜索树的高频问题(算法村第九关白银挑战)
基于二分查找的拓展问题 山峰数组的封顶索引 852. 山脉数组的峰顶索引 - 力扣(LeetCode) 给你由整数组成的山脉数组 arr ,返回满足 arr[0] < arr[1] < ... arr[i - 1] < arr[i] > arr[i 1] > ... > arr[arr.length - 1…...
Python爬虫快速入门
Python 爬虫Sutdy 1.基本类库 request(请求) 引入 from urllib import request定义url路径 url"http://www.baidu.com"进行请求,返回一个响应对象response responserequest.urlopen(url)读取响应体read()以字节形式打印网页源码 response.read()转码 编码 文本–by…...
部署MinIO
一、安装部署MINIO 1.1 下载 wget https://dl.min.io/server/minio/release/linux-arm64/minio chmod x minio mv minio /usr/local/bin/ # 控制台启动可参考如下命令, 守护进程启动请看下一个代码块 # ./minio server /data /data --console-address ":9001"1.2 配…...
RK3566环境搭建
环境:vmware16,ubuntu 18.04 安装依赖库: sudo apt-get install repo git ssh make gcc libssl-dev liblz4-tool expect g patchelf chrpath gawk texinfo chrpath diffstat binfmt-support qemu-user-static live-build bison flex fakero…...
精确掌控并发:滑动时间窗口算法在分布式环境下并发流量控制的设计与实现
这是《百图解码支付系统设计与实现》专栏系列文章中的第(15)篇,也是流量控制系列的第(2)篇。点击上方关注,深入了解支付系统的方方面面。 上一篇介绍了固定时间窗口算法在支付渠道限流的应用以及使用redis…...
Python展示 RGB立方体的二维切面视图
代码实现 import numpy as np import matplotlib.pyplot as plt# 生成 24-bit 全彩 RGB 立方体 def generate_rgb_cube():# 初始化一个 256x256x256 的三维数组rgb_cube np.zeros((256, 256, 256, 3), dtypenp.uint8)# 填充立方体for r in range(256):for g in range(256):fo…...
03 顺序表
目录 线性表顺序表练习 线性表(Linear list)是n个具有相同特性的数据元素的有限序列。线性表是一种在实际中广泛使用的数据结构,常见的线性表:顺序表、链表、栈、队列、字符串。。。 线性表在逻辑上时线性结构,是连续的一条直线。但在物理结…...
2023年全球软件开发大会(QCon北京站2023)9月:核心内容与学习收获(附大会核心PPT下载)
随着科技的飞速发展,全球软件开发大会(QCon)作为行业领先的技术盛会,为世界各地的专业人士提供了交流与学习的平台。本次大会汇集了全球的软件开发者、架构师、项目经理等,共同探讨软件开发的最新趋势、技术与实践。本…...
ChatGPT 和 文心一言 的优缺点及需求和使用场景
ChatGPT和文心一言是两种不同的自然语言生成模型,它们有各自的优点和缺点。 ChatGPT(Generative Pre-trained Transformer)是由OpenAI开发的生成式AI模型,它在庞大的文本数据集上进行了预训练,并可以根据输入生成具有上…...
架构师之超时未支付的订单进行取消操作的几种解决方案
今天给大家上一盘硬菜,并且是支付中非常重要的一个技术解决方案,有这块业务的同学注意自己尝试一把哈! 一、需求如下: 生成订单30分钟未支付,自动取消 生成订单60秒后,给用户发短信 对上述的需求,我们给…...
XXMI启动器:二次元游戏模组统一管理平台完整指南
XXMI启动器:二次元游戏模组统一管理平台完整指南 【免费下载链接】XXMI-Launcher Modding platform for GI, HSR, WW and ZZZ 项目地址: https://gitcode.com/gh_mirrors/xx/XXMI-Launcher 还在为多款二次元游戏模组管理而烦恼吗?XXMI启动器为你提…...
原神帧率解锁完整指南:5步突破60帧限制,体验丝滑游戏画面
原神帧率解锁完整指南:5步突破60帧限制,体验丝滑游戏画面 【免费下载链接】genshin-fps-unlock unlocks the 60 fps cap 项目地址: https://gitcode.com/gh_mirrors/ge/genshin-fps-unlock 对于追求极致流畅游戏体验的《原神》玩家来说࿰…...
vLLM推理引擎教程8-CUDA Graph内存池优化
1. CUDA Graph内存池优化原理 在vLLM这类大模型推理引擎中,CUDA Graph技术已经成为提升性能的标配方案。但很多开发者在使用过程中会遇到一个棘手问题:当需要处理不同batch size的请求时,显存碎片和重复分配会导致性能下降。这时候就需要引入…...
学术论文利器:使用LaTeX撰写cv_unet_image-colorization技术报告与实验图表
学术论文利器:使用LaTeX撰写cv_unet_image-colorization技术报告与实验图表 写技术报告或者论文,尤其是涉及图像处理、深度学习这类需要大量公式和图表的领域,你是不是也遇到过这些烦恼?用Word排版,公式稍微复杂一点就…...
Linux中的more 和 less区别对比分析
在 Linux/Unix 系统中,more 和 less 都是用于分页查看文本文件的命令,但 less 是 more 的增强版,功能更强大。以下是它们的核心区别和用法对比:1. 基础功能对比特性moreless(更强大)向前翻页❌ 仅支持向下翻…...
Mac开发者必备:OpenClaw+Qwen3.5-9B自动化测试流水线
Mac开发者必备:OpenClawQwen3.5-9B自动化测试流水线 1. 为什么开发者需要本地化CI/CD工具 作为一名长期在Mac上开发的全栈工程师,我一直在寻找一种轻量级的自动化测试方案。传统的Jenkins或GitHub Actions虽然强大,但对于个人项目和小团队来…...
Helm与Vault整合的实践之旅
在容器化和微服务架构的今天,管理配置文件和敏感信息变得愈发重要。使用Helm进行应用部署时,结合Vault来管理和注入机密信息是一个很好的实践。本文将通过一个实际的例子,详细说明如何在Helm Chart中使用Vault来配置和注入机密信息。 背景 Helm是一个包管理工具,可以帮助…...
SmallThinker-3B开源镜像实操:边缘部署+草稿加速双场景落地指南
SmallThinker-3B开源镜像实操:边缘部署草稿加速双场景落地指南 1. 引言:为什么你需要关注SmallThinker-3B? 如果你正在寻找一个既能在边缘设备上流畅运行,又能作为大模型“加速器”的AI工具,那么SmallThinker-3B-Pre…...
Intv_ai_mk11集成Node.js环境配置:快速构建实时聊天应用
Intv_ai_mk11集成Node.js环境配置:快速构建实时聊天应用 1. 环境准备与快速部署 在开始构建实时聊天应用之前,我们需要确保开发环境已经准备就绪。这里假设你已经具备基本的JavaScript和Node.js知识。 首先,确保你的系统已经安装了Node.js…...
节能模式!OpenClaw优化Qwen3-4B模型夜间任务功耗
节能模式!OpenClaw优化Qwen3-4B模型夜间任务功耗 1. 为什么需要关注OpenClaw的能耗问题 去年夏天,我的MacBook Pro在运行OpenClaw执行夜间数据整理任务时,风扇狂转的声音把我从睡梦中吵醒。摸到发烫的机身时,我突然意识到——这…...
