Python深度学习基于Tensorflow(3)Tensorflow 构建模型
文章目录
- 数据导入和数据可视化
- 数据集制作以及预处理
- 模型结构
- 低阶 API 构建模型
- 中阶 API 构建模型
- 高阶 API 构建模型
- 保存和导入模型
这里以实际项目CIFAR-10为例,分别使用低阶,中阶,高阶 API 搭建模型。
这里以CIFAR-10为数据集,CIFAR-10为小型数据集,一共包含10个类别的 RGB 彩色图像:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。图像的尺寸为 32×32(像素),3个通道 ,数据集中一共有 50000 张训练圄片和 10000 张测试图像。CIFAR-10数据集有3个版本,这里使用Python版本。
数据导入和数据可视化
这里不用书中给的CIFAR-10数据,直接使用TensorFlow自带的玩意导入数据,可能需要魔法,其实TensorFlow中的数据特别的经典。
接下来导入cifar10数据集并进行可视化展示
import matplotlib.pyplot as plt
import tensorflow as tf(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# x_train.shape, y_train.shape, x_test.shape, y_test.shape
# ((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))index_name = {0:'airplane',1:'automobile',2:'bird',3:'cat',4:'deer',5:'dog',6:'frog',7:'horse',8:'ship',9:'truck'
}def plot_100_img(imgs, labels):fig = plt.figure(figsize=(20,20))for i in range(10):for j in range(10):plt.subplot(10,10,i*10+j+1)plt.imshow(imgs[i*10+j])plt.title(index_name[labels[i*10+j][0]])plt.axis('off')plt.show()plot_100_img(x_test[:100])
数据集制作以及预处理
数据集预处理很简单就能实现,直接一行代码。
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))# 提取出一行数据
# train_data.take(1).get_single_element()
这里接着对数据预处理操作,也很容易就能实现。
def process_data(img, label):img = tf.cast(img, tf.float32) / 255.0return img, labeltrain_data = train_data.map(process_data)# 提取出一行数据
# train_data.take(1).get_single_element()
这里对数据还有一些存储和提取操作
dataset 中 shuffle()、repeat()、batch()、prefetch()等函数的主要功能如下。
1)repeat(count=None) 表示重复此数据集 count 次,实际上,我们看到 repeat 往往是接在 shuffle 后面的。为何要这么做,而不是反过来,先 repeat 再 shuffle 呢? 如果shuffle 在 repeat 之后,epoch 与 epoch 之间的边界就会模糊,出现未遍历完数据,已经计算过的数据又出现的情况。
2)shuffle(buffer_size, seed=None, reshuffle_each_iteration=None) 表示将数据打乱,数值越大,混乱程度越大。为了完全打乱,buffer_size 应等于数据集的数量。
3)batch(batch_size, drop_remainder=False) 表示按照顺序取出 batch_size 大小数据,最后一次输出可能小于batch ,如果程序指定了每次必须输入进批次的大小,那么应将drop_remainder 设置为 True 以防止产生较小的批次,默认为 False。
4)prefetch(buffer_size) 表示使用一个后台线程以及一个buffer来缓存batch,提前为模型的执行程序准备好数据。一般来说,buffer的大小应该至少和每一步训练消耗的batch数量一致,也就是 GPU/TPU 的数量。我们也可以使用AUTOTUNE来设置。创建一个Dataset便可从该数据集中预提取元素,注意:examples.prefetch(2) 表示将预取2个元素(2个示例),而examples.batch(20).prefetch(2) 表示将预取2个元素(2个批次,每个批次有20个示例),buffer_size 表示预提取时将缓冲的最大元素数返回 Dataset。
最后我们对数据进行一些缓存操作
learning_rate = 0.0002
batch_size = 64
training_steps = 40000
display_step = 1000AUTOTUNE = tf.data.experimental.AUTOTUNE
train_data = train_data.map(process_data).shuffle(5000).repeat(training_steps).batch(batch_size).prefetch(buffer_size=AUTOTUNE)
目前数据准备完毕!
模型结构
模型的结构如下,现在使用低阶,中阶,高阶 API 来构建这一个模型
低阶 API 构建模型
import matplotlib.pyplot as plt
import tensorflow as tf## 定义模型
class CustomModel(tf.Module):def __init__(self, name=None):super(CustomModel, self).__init__(name=name)self.w1 = tf.Variable(tf.initializers.RandomNormal()([32*32*3, 256]))self.b1 = tf.Variable(tf.initializers.RandomNormal()([256]))self.w2 = tf.Variable(tf.initializers.RandomNormal()([256, 128]))self.b2 = tf.Variable(tf.initializers.RandomNormal()([128]))self.w3 = tf.Variable(tf.initializers.RandomNormal()([128, 64]))self.b3 = tf.Variable(tf.initializers.RandomNormal()([64]))self.w4 = tf.Variable(tf.initializers.RandomNormal()([64, 10]))self.b4 = tf.Variable(tf.initializers.RandomNormal()([10]))def __call__(self, x):x = tf.cast(x, tf.float32)x = tf.reshape(x, [x.shape[0], -1])x = tf.nn.relu(x @ self.w1 + self.b1)x = tf.nn.relu(x @ self.w2 + self.b2)x = tf.nn.relu(x @ self.w3 + self.b3)x = tf.nn.softmax(x @ self.w4 + self.b4)return x
model = CustomModel()## 定义损失
def compute_loss(y, y_pred):y_pred = tf.clip_by_value(y_pred, 1e-9, 1.)loss = tf.keras.losses.sparse_categorical_crossentropy(y, y_pred)return tf.reduce_mean(loss)## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)## 定义准确率
def compute_accuracy(y, y_pred):correct_pred = tf.equal(tf.argmax(y_pred, axis=1), tf.cast(tf.reshape(y, -1), tf.int64))correct_pred = tf.cast(correct_pred, tf.float32)return tf.reduce_mean(correct_pred)## 定义一次epoch
def train_one_epoch(x, y):with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)accuracy = compute_accuracy(y, y_pred)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))return loss.numpy(), accuracy.numpy()## 开始训练loss_list, acc_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):loss, acc = train_one_epoch(batch_x, batch_y)loss_list.append(loss)acc_list.append(acc)if i % 10 == 0:print(f'第{i}次训练->', 'loss:' ,loss, 'acc:', acc)
中阶 API 构建模型
## 定义模型
class CustomModel(tf.Module):def __init__(self):super(CustomModel, self).__init__()self.flatten = tf.keras.layers.Flatten()self.dense_1 = tf.keras.layers.Dense(256, activation='relu')self.dense_2 = tf.keras.layers.Dense(128, activation='relu')self.dense_3 = tf.keras.layers.Dense(64, activation='relu')self.dense_4 = tf.keras.layers.Dense(10, activation='softmax')def __call__(self, x):x = self.flatten(x)x = self.dense_1(x)x = self.dense_2(x)x = self.dense_3(x)x = self.dense_4(x)return xmodel = CustomModel()## 定义损失以及准确率
compute_loss = tf.keras.losses.SparseCategoricalCrossentropy()
train_loss = tf.keras.metrics.Mean()
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)## 定义一次epoch
def train_one_epoch(x, y):with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))train_loss(loss)train_accuracy(y, y_pred)## 开始训练
loss_list, accuracy_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):train_one_epoch(batch_x, batch_y)loss_list.append(train_loss.result())accuracy_list.append(train_accuracy.result())if i % 10 == 0:print(f"第{i}次训练: loss: {train_loss.result()} accuarcy: {train_accuracy.result()}")
高阶 API 构建模型
## 定义模型
model = tf.keras.Sequential([tf.keras.layers.Input(shape=[32,32,3]),tf.keras.layers.Flatten(),tf.keras.layers.Dense(256, activation='relu'),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(10, activation='softmax'),
])## 定义optimizer,loss, accuracy
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),loss = tf.keras.losses.SparseCategoricalCrossentropy(),metrics=['accuracy']
)## 开始训练
model.fit(train_data.take(10000))
保存和导入模型
保存模型
tf.keras.models.save_model(model, 'model_folder')
导入模型
model = tf.keras.models.load_model('model_folder')
相关文章:

Python深度学习基于Tensorflow(3)Tensorflow 构建模型
文章目录 数据导入和数据可视化数据集制作以及预处理模型结构低阶 API 构建模型中阶 API 构建模型高阶 API 构建模型保存和导入模型 这里以实际项目CIFAR-10为例,分别使用低阶,中阶,高阶 API 搭建模型。 这里以CIFAR-10为数据集,C…...

火爆多年的抖音小店,2024年想要入驻需要什么条件呢?
大家好,我是电商糖果 我相信现在只要会上网的年轻人,对抖音小店一定不会感觉陌生。 它最近几年的风头,可是远远超过某宝,某多多了。 不少抖音用户也有了在抖音购物的习惯,现在的抖音上入驻了上百万家电商商家。 这…...

STM32G030C8T6:EEPROM读写实验(I2C通信)
本专栏记录STM32开发各个功能的详细过程,方便自己后续查看,当然也供正在入门STM32单片机的兄弟们参考; 本小节的目标是,系统主频64 MHZ,采用高速外部晶振,实现PB11,PB10 引脚模拟I2C 时序,对M24C08 的EEPRO…...

使用Git管理github的代码库-上
1、下载安装Git https://download.csdn.net/download/notfindjob/11451730?spm1001.2014.3001.5503 2、注册一个github的账号(已经注册的,可略过这一步) 3、打开git命令行,配置github账号 git config --global user.name &quo…...

经典文献阅读之--D-Map(无需射线投射的高分辨率激光雷达传感器的占据栅格地图)
0. 简介 占用地图是机器人系统中推理环境未知和已知区域的基本组成部分。《Occupancy Grid Mapping without Ray-Casting for High-resolution LiDAR Sensors》介绍了一种高分辨率LiDAR传感器的高效占用地图框架,称为D-Map。该框架引入了三个主要创新来解决占用地图…...

开源免费的定时任务管理系统:Gocron
Gocron:精准调度未来,你的全能定时任务管理工具!- 精选真开源,释放新价值。 概览 Gocron是github上一个开源免费的定时任务管理系统。它使用Go语言开发,是一个轻量级定时任务集中调度和管理系统,用于替代L…...

从零开始详解OpenCV车道线检测
前言 车道线检测是智能驾驶和智能交通系统中的重要组成部分,对于提高道路安全、交通效率和驾驶舒适性具有重要意义。在本篇文章中将介绍使用OpenCV进行车道线的检测 详解 导入包 import cv2 import matplotlib.pyplot as plt import numpy as np读入图像并灰度化…...
【Java代码审计】逻辑漏洞篇
【Java代码审计】逻辑漏洞篇 逻辑漏洞概述常见逻辑漏洞点 逻辑漏洞概述 逻辑漏洞一般是由于源程序自身逻辑存在缺陷,导致攻击者可以对逻辑缺陷进行深层次的利用。逻辑漏洞出现较为频繁的地方一般是登录验证逻辑、验证码校验逻辑、密码找回逻辑、权限校验逻辑以及支…...
SSH简介
SSH,全名叫Secure Shell,你可以想象它是一个超级安全的管道,专门用来远程操控电脑的。就好比你在家用遥控器指挥远处的电视换台,但比这高级多了,因为它是专门为电脑设计的。 为什么需要SSH? 在互联网的早期…...
Oracle的高级分组函数grouping和grouping_id
在网上对Oracle的高级分组函数grouping和grouping_id的讲解并不多,特别是grouping_id,还有解说有误的。经过1天研究,已经完全掌握了两个函数的作用和用法,下面简单的讲述即可明白。下面给大家分享。 GROUPING 函数 语法:grouping(表达式) 作用: GROUPING将超聚…...
SqlServer 查询数据库 和 数据表 大小的语句
–Sqlserver 查询数据库 大小 SELECT * FROM (SELECT DB_NAME(database_id) AS DatabaseName,type_desc AS FileType,name AS FileName,size * 8 / 1024/1024 AS FileSizeGBFROM sys.master_filesWHERE type 0 -- 数据文件AND state 0 -- 在线状态 ) T1 ORDER BY FileSizeG…...
特殊类的设计与单例模式
1、特殊类的设计 如何设计出一个创建出的对象只能在堆上的类?将类的默认构造函数设置为私有,再将类的拷贝构造函数设置为delete,设置静态函数GetObj,内部调用new HeapOnly,这样就只能在堆上开辟空间。 class HeapOnly…...

MySQL从入门到高级 --- 6.函数
文章目录 第六章:6.函数6.1 聚合函数6.2 数学函数6.3 字符串函数6.4 日期函数6.4.1 日期格式 6.5 控制流函数6.5.1 if逻辑判断语句6.5.2 case when语句 6.6 窗口函数6.6.1 序号函数6.6.2 开窗聚合函数6.6.3 分布函数6.6.4 前后函数6.6.5 头尾函数6.6.6 其他函数6.7 …...

Qt---信号和槽
一、信号和槽机制 所谓信号槽,实际就是观察者模式。当某个事件发生之后,比如,按钮检测到自己被点击了一下,它就会发出一个信号(signal)。这种发出是没有目的的,类似广播。如果有对象对这个信号…...

POCEXP编写—文件上传案例
POC&EXP编写—文件上传案例 1. 前言2. 文件上传案例2.1. Burp抓包2.2. 基础代码实践2.2.1. 优化代码 2.3. 整体代码2.3.1. 木马测试 1. 前言 之前的文章基本上都是一些相对来说都是验证类的或者说是一些代码执行类的,相对来说都不是太复杂,而这篇会…...

C#知识|上位机UI设计-详情窗体设计思路及流程(实例)
哈喽,你好啊,我是雷工! 上两节练习记录了登录窗体和主窗体的实现过程,本节继续练习内容窗体的实现,以下为练习笔记。 01 详情窗体效果展示: 02 添加窗体并设置属性 在之前练习项目的基础上添加一个Windows窗体,设置名称为:FrmIPManage.cs 设置窗体的边框和标题栏的外…...

目标检测——印度车辆数据集
引言 亲爱的读者们,您是否在寻找某个特定的数据集,用于研究或项目实践?欢迎您在评论区留言,或者通过公众号私信告诉我,您想要的数据集的类型主题。小编会竭尽全力为您寻找,并在找到后第一时间与您分享。 …...
Zotero Word中插入带超链接的参考文献
Zotero 超链接 找了好多原代码,最接近能实施的为: https://blog.csdn.net/weixin_47244593/article/details/129072589 但是,就是向他说的一样会报错,我修改了代码,遇见报错的地方会直接跳过不执行,事后找…...
如何在服务器上下载,解压github上的代码
在github上找到对应仓库,找到平时download zip的地方,右键它,复制链接。在远程的终端里使用wget 链接 命令就可以得到zip了。 解压方法: -c :新建打包文件 -t :查看打包文件的内容含有哪些文件名 -x &…...

BGP学习二:BGP通告原则,BGP反射器,BGP路径属性细致讲解,新手小白无负担
目录 一.AS号 二.BGP路由生成 1.network 2.import-route引入 三.BGP通告原则 1.只发布最优且有效的路由 2.从EBGP获取的路由,会发布给所有对等体 3.水平分割原则 4.IBGP学习BGP默认不发送给EBGP,但如果也从IGP学习到了这条路由,就发…...

使用VSCode开发Django指南
使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架,专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用,其中包含三个使用通用基本模板的页面。在此…...
大学生职业发展与就业创业指导教学评价
这里是引用 作为软工2203/2204班的学生,我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要,而您认真负责的教学态度,让课程的每一部分都充满了实用价值。 尤其让我…...
Angular微前端架构:Module Federation + ngx-build-plus (Webpack)
以下是一个完整的 Angular 微前端示例,其中使用的是 Module Federation 和 npx-build-plus 实现了主应用(Shell)与子应用(Remote)的集成。 🛠️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...

FFmpeg:Windows系统小白安装及其使用
一、安装 1.访问官网 Download FFmpeg 2.点击版本目录 3.选择版本点击安装 注意这里选择的是【release buids】,注意左上角标题 例如我安装在目录 F:\FFmpeg 4.解压 5.添加环境变量 把你解压后的bin目录(即exe所在文件夹)加入系统变量…...
tomcat指定使用的jdk版本
说明 有时候需要对tomcat配置指定的jdk版本号,此时,我们可以通过以下方式进行配置 设置方式 找到tomcat的bin目录中的setclasspath.bat。如果是linux系统则是setclasspath.sh set JAVA_HOMEC:\Program Files\Java\jdk8 set JRE_HOMEC:\Program Files…...

麒麟系统使用-进行.NET开发
文章目录 前言一、搭建dotnet环境1.获取相关资源2.配置dotnet 二、使用dotnet三、其他说明总结 前言 麒麟系统的内核是基于linux的,如果需要进行.NET开发,则需要安装特定的应用。由于NET Framework 是仅适用于 Windows 版本的 .NET,所以要进…...

qt+vs Generated File下的moc_和ui_文件丢失导致 error LNK2001
qt 5.9.7 vs2013 qt add-in 2.3.2 起因是添加一个新的控件类,直接把源文件拖进VS的项目里,然后VS卡住十秒,然后编译就报一堆 error LNK2001 一看项目的Generated Files下的moc_和ui_文件丢失了一部分,导致编译的时候找不到了。因…...

java 局域网 rtsp 取流 WebSocket 推送到前端显示 低延迟
众所周知 摄像头取流推流显示前端延迟大 传统方法是服务器取摄像头的rtsp流 然后客户端连服务器 中转多了,延迟一定不小。 假设相机没有专网 公网 1相机自带推流 直接推送到云服务器 然后客户端拉去 2相机只有rtsp ,边缘服务器拉流推送到云服务器 …...

2025-06-01-Hive 技术及应用介绍
Hive 技术及应用介绍 参考资料 Hive 技术原理Hive 架构及应用介绍Hive - 小海哥哥 de - 博客园https://cwiki.apache.org/confluence/display/Hive/Home(官方文档) Apache Hive 是基于 Hadoop 构建的数据仓库工具,它为海量结构化数据提供类 SQL 的查询能力…...

【靶场】XXE-Lab xxe漏洞
前言 学习xxe漏洞,搭了个XXE-Lab的靶场 一、搭建靶场 现在需要登录,不知道用户名密码,先随便试试抓包 二、判断是否存在xxe漏洞 1.首先登录抓包 看到xml数据解析,由此判断和xxe漏洞有关,但还不确定xxe漏洞是否存在。 2.尝试xxe 漏洞 判断是否存在xxe漏洞 A.send to …...