深度学习笔记(七)——基于Iris/MNIST数据集构建基础的分类网络算法实战
文中程序以Tensorflow-2.6.0为例
部分概念包含笔者个人理解,如有遗漏或错误,欢迎评论或私信指正。
截图和程序部分引用自北京大学机器学习公开课
认识网络的构建结构
在神经网络的构建过程中,都避不开以下几个步骤:
- 导入网络和依赖模块
- 原始数据处理和清洗
- 加载训练和测试数据
- 构建网络结构,确定网络优化方法
- 将数据送入网络进行训练,同时判断预测效果
- 保存模型
- 部署算法,使用新的数据进行预测推理
使用Keras快速构建网络的必要API
在tensorflow2版本中将很多基础函数进行了二次封装,进一步急速了算法初期的构建实现。通过keras提供的很多高级API可以在较短的代码体量上实现网络功能。同时通过搭配tf中的基础功能函数可以实现各种不同类型的卷积和组合操作。正是这中高级API和底层元素及的操作大幅度的提升了tensorflow的自由程度和易用性。
常用网络
全连接层
tf.keras.layers.Dense(units=3, activation=tf.keras.activations.softmax, kernel_regularizer=tf.keras.regularizers.L2())
units:维数(神经元个数)
activation:激活函数,可选:relu softmax sigmoid tanh,这里记不住的话可以用tf.keras.activations.
逐个查看
kernel_regularizer:正则化函数,同样的可以使用tf.keras.regularizers.
逐个查看
全连接层是标准的神经元组成,更多被用在网络的后端或解码端(Decoder)用来输出预测数据。
拉伸层(维度展平)
tf.keras.layers.Flatten()
这个函数默认不需要输入参数,直接使用,它会将多维的数据按照每一行依次排开首尾连接变成一个一维的张量。通常在数据输入到全连接层之前使用。
卷积层
tf.keras.layers.Conv2D(filters=3, kernel_size=3, strides=1, padding='valid')
filters:卷积核个数
kernel_size:卷积核尺寸
strides:卷积核步长,卷积核是在原始数据上滑动遍历完成数据计算。
padding:可填 ‘valid’ ‘same’,是否使用全零填充,影响最后卷积结果的大小。
卷积一般被用来提取数据的数据特征。卷积最关键的就是卷积核个数和卷积核尺寸。假设输入一个1nn大小的张量,经过x个卷积核+步长为2+尺寸可以整除n的卷积层之后会输出一个x*(n/2)*(n/2)大小的张量。可以理解为卷积步长和卷积核大小影响输出张量的长宽,卷积核的大小影响输出张量的深度。
构建网络
使用Sequential构建简单网络,或者构建网络模块。列表中顺序包含网络的各个层。
tf.keras.models.Sequential([ ])
使用独立的class构建,这里定义一个类继承自 tensorflow.keras.Model
后面基本是标准结构>初始化相关参数>定义网络层>重写call函数定义前向传播层的连接顺序。后续随着使用的深入可以进一步的添加更多函数来实现不同类型的网络。
class mynnModel(Model): # 继承from tensorflow.keras import Model 作为父类def __init__(self):super(IrisModel, self).__init__() # 初始化父类的参数self.d1 = layers.Dense(units=3, activation=tf.keras.activations.softmax, kernel_regularizer=tf.keras.regularizers.L2())def call(self, input): # 重写前向传播函数y = self.d1(input)return ymodel = IrisModel()
训练及其参数设置
设置训练参数
tensorflow.keras.Model.compile(optimizer=参数更新优化器,loss=损失函数metrics=准确率计算方式,即输出数据类型和标签数据类型如何对应)
具体参数可以看下面的内容:
optimizer:参数优化器 SGD: tf.keras.optimizers.SGD(learning_rate=0.1,momentum=动量参数) learning_rate学习率,momentum动量参数AdaGrad: tf.keras.optimizers.Adagrad(learning_rate=学习率)Adam: tf.keras.optimizers.Adam(learning_rate=学习率 , beta_1=0.9, beta_2=0.999)
loss:损失函数MSE: tf.keras.losses.MeanSquaredError()交叉熵损失: tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False) from_logits=true时输出值经过一次softmax概率归一化
metrics:准确率计算方式,就是输出数据类型和标签数据类型如何对应数值型(两个都是序列值): 'accuracy'都是独热码: 'categorical_accuracy'标签是数值,输出是独热码: 'sparse_categorical_accuracy'
训练
tensorflow.keras.Model.model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
网络传入参数含义如下:
输入的数据依次为:输入训练特征数据,标签数据,单次输入数据量,迭代次数
validation_split=从训练集划分多少比例数据用来测试 / validation_data=(测试特征数据,测试标签数据) 这两个参数智能二选一
validation_freq=多少次epoch测试一次
输出网络信息
tensorflow.keras.Model.model.summary()
上面这个函数可以在训练结束或者训练开始之前输出一次网络的结构信息用于确认。
实际应用展示
环境
软件环境的配置可以查看环境配置流程说明
cuda = 11.8 # CUDA也可以使用11.2版本
python=3.7
numpy==1.19.5
matplotlib== 3.5.3
notebook==6.4.12
scikit-learn==1.2.0
tensorflow==2.6.0
keras==2.6.0
使用iris数据集构建基础的分类网络
import tensorflow as tf
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)model = tf.keras.models.Sequential([ tf.keras.layers.Dense(3, activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary( )
通过上面这样几行简单的代码,我们实现了对iris数据的分类训练。在上面的代码中使用了Sequential函数来构建网络。
使用MNIST数据集设计分类网络
在开始下面的代码之前,要先下载对应的数据 https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 复制这段网址在浏览器打开会直接下载数据,然后将下载好的mnist.npz复制到一个新的路径下,然后在tf.keras.datasets.mnist.load_data(path=‘you file path ’)代码中的这行里修改为你的路径,注意要使用绝对路径。
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras import layers
from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path='E:\Tensorflow\data\mnist.npz') # 注意替换自己的使用绝对路径
x_train, x_test = x_train/255.0, x_test/255.0 # 图像数据归一化
print('训练集样本的大小:', x_train.shape)
print('训练集标签的大小:', y_train.shape)
print('测试集样本的大小:', x_test.shape)
print('测试集标签的大小:', y_test.shape)
#可视化样本,下面是输出了训练集中前20个样本
fig, ax = plt.subplots(nrows=4,ncols=5,sharex='all',sharey='all')
ax = ax.flatten()
for i in range(20):img = x_train[i].reshape(28, 28)ax[i].imshow(img,cmap='Greys')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
# 定义网络结构
class mnisModel(Model):def __init__(self, *args, **kwargs):super(mnisModel, self).__init__(*args, **kwargs)self.flatten1=layers.Flatten()self.d1=layers.Dense(128, activation=tf.keras.activations.relu)self.d2=layers.Dense(10, activation=tf.keras.activations.softmax)def call(self, input):x = self.flatten1(input)x = self.d1(x)x = self.d2(x)return(x)
model = mnisModel()
#设置训练参数
model.compile(optimizer='adam', # 'adam' tf.keras.optimizers.Adam(learning_rate=0.4 , beta_1=0.9, beta_2=0.999)loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
# 训练
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data = (x_test, y_test), validation_freq=1)
model.summary()
运行后会先显示数据集中的前二十个数字
关闭数字展示窗口后开始训练,并看到训练的过程
相关文章:

深度学习笔记(七)——基于Iris/MNIST数据集构建基础的分类网络算法实战
文中程序以Tensorflow-2.6.0为例 部分概念包含笔者个人理解,如有遗漏或错误,欢迎评论或私信指正。 截图和程序部分引用自北京大学机器学习公开课 认识网络的构建结构 在神经网络的构建过程中,都避不开以下几个步骤: 导入网络和依…...

Windows启动MongoDB服务报错(错误 1053:服务没有及时响应启动或控制请求)
问题描述:修改MongoDB服务bin目录下的mongod.cfg,然后在任务管理器找到MongoDB服务-->右键-->点击【开始】,启动失败无提示: 右键点击任务管理器的MongoDB服务-->点击【打开服务】,跳转到服务页面-->找到M…...
Android Framework 常见解决方案(25-2)定制CPUSET解决方案-system修改及编译部分调整
1 原理说明 这个方案有如下基本需求: 构建自定义CPUSET,/dev/cpuset中包含一个全新的cpuset分组。且可以通过set_cpuset_policy和set_sched_policy接口可以设置自定义CPUSET。开机启动后可以通过zygote判定来对特定的应用进程设置CPUSET,并…...

OpenAI推出GPT商店和ChatGPT Team服务
🦉 AI新闻 🚀 OpenAI推出GPT商店和ChatGPT Team服务 摘要:OpenAI正式推出了其GPT商店和ChatGPT Team服务。用户已经创建了超过300万个ChatGPT自定义版本,并分享给其他人使用。GPT商店集结了用户为各种任务创建的定制化ChatGPT&a…...

3D建模素材分层渲染怎么操作?
在3D建模素材分层渲染过程中,需要将场景中的元素分到不同的层里,然后分别进行渲染。以下是一个简单的方法: 1、打开要渲染的3D建模素材。 2、在场景中选择要分层的元素,然后在软件的图层面板中新建图层,将元素拖拽到新…...
SAICP(模拟退火迭代最近点)的实现
SAICP(模拟退火迭代最近点)的实现 注: 本系列所有文章在github开源, 也是我个人的学习笔记, 欢迎大家去star以及fork, 感谢! 仓库地址: pointcloud-processing-visualization 总结一下上周的学习情况 ICP会存在局部最小值的问题, 这个问题可能即使是没有实际遇到过, 也或多…...

FineBI实战项目一(23):订单商品分类词云图分析开发
点击新建组件,创建订单商品分类词云图组件。 选择词云,拖拽catName到颜色和文本,拖拽cat到大小。 将组件拖拽到仪表板。 结果如下:...
DOS命令
当使用DOS命令时,可以在命令提示符下输入各种命令以执行不同的任务。以下是一些常见DOS命令的详细说明: dir (Directory): 列出当前目录中的文件和子目录。 用法: dir [drive:][path][filename] [/p] [/w] cd (Change Directory): 更改当前目录。 用法: …...
【Python】torch中的.detach()函数详解和示例
在PyTorch中,.detach()是一个用于张量的方法,主要用于创建该张量的一个“离断”版本。这个方法在很多情况下都非常有用,例如在缓存释放、模型评估和简化计算图等场景中。 .detach()方法用于从计算图中分离一个张量,这意味着它创建…...

二级域名分发系统源码 对接易支付php源码 全开源
全面开源的易支付PHP源码分享:实现二级域名分发对接 首先,在epay的config.php文件中修改您的支付域名。 随后,在二级域名分发网站上做相应修改。 伪静态 location / { try_files $uri $uri/ /index.php?$query_string; } 源码下载&#…...

二分查找与搜索树的高频问题(算法村第九关白银挑战)
基于二分查找的拓展问题 山峰数组的封顶索引 852. 山脉数组的峰顶索引 - 力扣(LeetCode) 给你由整数组成的山脉数组 arr ,返回满足 arr[0] < arr[1] < ... arr[i - 1] < arr[i] > arr[i 1] > ... > arr[arr.length - 1…...
Python爬虫快速入门
Python 爬虫Sutdy 1.基本类库 request(请求) 引入 from urllib import request定义url路径 url"http://www.baidu.com"进行请求,返回一个响应对象response responserequest.urlopen(url)读取响应体read()以字节形式打印网页源码 response.read()转码 编码 文本–by…...

部署MinIO
一、安装部署MINIO 1.1 下载 wget https://dl.min.io/server/minio/release/linux-arm64/minio chmod x minio mv minio /usr/local/bin/ # 控制台启动可参考如下命令, 守护进程启动请看下一个代码块 # ./minio server /data /data --console-address ":9001"1.2 配…...

RK3566环境搭建
环境:vmware16,ubuntu 18.04 安装依赖库: sudo apt-get install repo git ssh make gcc libssl-dev liblz4-tool expect g patchelf chrpath gawk texinfo chrpath diffstat binfmt-support qemu-user-static live-build bison flex fakero…...

精确掌控并发:滑动时间窗口算法在分布式环境下并发流量控制的设计与实现
这是《百图解码支付系统设计与实现》专栏系列文章中的第(15)篇,也是流量控制系列的第(2)篇。点击上方关注,深入了解支付系统的方方面面。 上一篇介绍了固定时间窗口算法在支付渠道限流的应用以及使用redis…...

Python展示 RGB立方体的二维切面视图
代码实现 import numpy as np import matplotlib.pyplot as plt# 生成 24-bit 全彩 RGB 立方体 def generate_rgb_cube():# 初始化一个 256x256x256 的三维数组rgb_cube np.zeros((256, 256, 256, 3), dtypenp.uint8)# 填充立方体for r in range(256):for g in range(256):fo…...

03 顺序表
目录 线性表顺序表练习 线性表(Linear list)是n个具有相同特性的数据元素的有限序列。线性表是一种在实际中广泛使用的数据结构,常见的线性表:顺序表、链表、栈、队列、字符串。。。 线性表在逻辑上时线性结构,是连续的一条直线。但在物理结…...

2023年全球软件开发大会(QCon北京站2023)9月:核心内容与学习收获(附大会核心PPT下载)
随着科技的飞速发展,全球软件开发大会(QCon)作为行业领先的技术盛会,为世界各地的专业人士提供了交流与学习的平台。本次大会汇集了全球的软件开发者、架构师、项目经理等,共同探讨软件开发的最新趋势、技术与实践。本…...
ChatGPT 和 文心一言 的优缺点及需求和使用场景
ChatGPT和文心一言是两种不同的自然语言生成模型,它们有各自的优点和缺点。 ChatGPT(Generative Pre-trained Transformer)是由OpenAI开发的生成式AI模型,它在庞大的文本数据集上进行了预训练,并可以根据输入生成具有上…...

架构师之超时未支付的订单进行取消操作的几种解决方案
今天给大家上一盘硬菜,并且是支付中非常重要的一个技术解决方案,有这块业务的同学注意自己尝试一把哈! 一、需求如下: 生成订单30分钟未支付,自动取消 生成订单60秒后,给用户发短信 对上述的需求,我们给…...

业务系统对接大模型的基础方案:架构设计与关键步骤
业务系统对接大模型:架构设计与关键步骤 在当今数字化转型的浪潮中,大语言模型(LLM)已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中,不仅可以优化用户体验,还能为业务决策提供…...

7.4.分块查找
一.分块查找的算法思想: 1.实例: 以上述图片的顺序表为例, 该顺序表的数据元素从整体来看是乱序的,但如果把这些数据元素分成一块一块的小区间, 第一个区间[0,1]索引上的数据元素都是小于等于10的, 第二…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序
一、开发环境准备 工具安装: 下载安装DevEco Studio 4.0(支持HarmonyOS 5)配置HarmonyOS SDK 5.0确保Node.js版本≥14 项目初始化: ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...

Mac软件卸载指南,简单易懂!
刚和Adobe分手,它却总在Library里给你写"回忆录"?卸载的Final Cut Pro像电子幽灵般阴魂不散?总是会有残留文件,别慌!这份Mac软件卸载指南,将用最硬核的方式教你"数字分手术"࿰…...
C++中string流知识详解和示例
一、概览与类体系 C 提供三种基于内存字符串的流,定义在 <sstream> 中: std::istringstream:输入流,从已有字符串中读取并解析。std::ostringstream:输出流,向内部缓冲区写入内容,最终取…...
相机Camera日志分析之三十一:高通Camx HAL十种流程基础分析关键字汇总(后续持续更新中)
【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了:有对最普通的场景进行各个日志注释讲解,但相机场景太多,日志差异也巨大。后面将展示各种场景下的日志。 通过notepad++打开场景下的日志,通过下列分类关键字搜索,即可清晰的分析不同场景的相机运行流程差异…...

BCS 2025|百度副总裁陈洋:智能体在安全领域的应用实践
6月5日,2025全球数字经济大会数字安全主论坛暨北京网络安全大会在国家会议中心隆重开幕。百度副总裁陈洋受邀出席,并作《智能体在安全领域的应用实践》主题演讲,分享了在智能体在安全领域的突破性实践。他指出,百度通过将安全能力…...
AI编程--插件对比分析:CodeRider、GitHub Copilot及其他
AI编程插件对比分析:CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展,AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者,分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...
【HarmonyOS 5 开发速记】如何获取用户信息(头像/昵称/手机号)
1.获取 authorizationCode: 2.利用 authorizationCode 获取 accessToken:文档中心 3.获取手机:文档中心 4.获取昵称头像:文档中心 首先创建 request 若要获取手机号,scope必填 phone,permissions 必填 …...

分布式增量爬虫实现方案
之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面,避免重复抓取,以节省资源和时间。 在分布式环境下,增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路:将增量判…...