Python深度学习基于Tensorflow(3)Tensorflow 构建模型
文章目录
- 数据导入和数据可视化
- 数据集制作以及预处理
- 模型结构
- 低阶 API 构建模型
- 中阶 API 构建模型
- 高阶 API 构建模型
- 保存和导入模型
这里以实际项目CIFAR-10为例,分别使用低阶,中阶,高阶 API 搭建模型。
这里以CIFAR-10为数据集,CIFAR-10为小型数据集,一共包含10个类别的 RGB 彩色图像:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。图像的尺寸为 32×32(像素),3个通道 ,数据集中一共有 50000 张训练圄片和 10000 张测试图像。CIFAR-10数据集有3个版本,这里使用Python版本。
数据导入和数据可视化
这里不用书中给的CIFAR-10数据,直接使用TensorFlow自带的玩意导入数据,可能需要魔法,其实TensorFlow中的数据特别的经典。
接下来导入cifar10数据集并进行可视化展示
import matplotlib.pyplot as plt
import tensorflow as tf(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# x_train.shape, y_train.shape, x_test.shape, y_test.shape
# ((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))index_name = {0:'airplane',1:'automobile',2:'bird',3:'cat',4:'deer',5:'dog',6:'frog',7:'horse',8:'ship',9:'truck'
}def plot_100_img(imgs, labels):fig = plt.figure(figsize=(20,20))for i in range(10):for j in range(10):plt.subplot(10,10,i*10+j+1)plt.imshow(imgs[i*10+j])plt.title(index_name[labels[i*10+j][0]])plt.axis('off')plt.show()plot_100_img(x_test[:100])
数据集制作以及预处理
数据集预处理很简单就能实现,直接一行代码。
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))# 提取出一行数据
# train_data.take(1).get_single_element()
这里接着对数据预处理操作,也很容易就能实现。
def process_data(img, label):img = tf.cast(img, tf.float32) / 255.0return img, labeltrain_data = train_data.map(process_data)# 提取出一行数据
# train_data.take(1).get_single_element()
这里对数据还有一些存储和提取操作
dataset 中 shuffle()、repeat()、batch()、prefetch()等函数的主要功能如下。
1)repeat(count=None) 表示重复此数据集 count 次,实际上,我们看到 repeat 往往是接在 shuffle 后面的。为何要这么做,而不是反过来,先 repeat 再 shuffle 呢? 如果shuffle 在 repeat 之后,epoch 与 epoch 之间的边界就会模糊,出现未遍历完数据,已经计算过的数据又出现的情况。
2)shuffle(buffer_size, seed=None, reshuffle_each_iteration=None) 表示将数据打乱,数值越大,混乱程度越大。为了完全打乱,buffer_size 应等于数据集的数量。
3)batch(batch_size, drop_remainder=False) 表示按照顺序取出 batch_size 大小数据,最后一次输出可能小于batch ,如果程序指定了每次必须输入进批次的大小,那么应将drop_remainder 设置为 True 以防止产生较小的批次,默认为 False。
4)prefetch(buffer_size) 表示使用一个后台线程以及一个buffer来缓存batch,提前为模型的执行程序准备好数据。一般来说,buffer的大小应该至少和每一步训练消耗的batch数量一致,也就是 GPU/TPU 的数量。我们也可以使用AUTOTUNE来设置。创建一个Dataset便可从该数据集中预提取元素,注意:examples.prefetch(2) 表示将预取2个元素(2个示例),而examples.batch(20).prefetch(2) 表示将预取2个元素(2个批次,每个批次有20个示例),buffer_size 表示预提取时将缓冲的最大元素数返回 Dataset。
最后我们对数据进行一些缓存操作
learning_rate = 0.0002
batch_size = 64
training_steps = 40000
display_step = 1000AUTOTUNE = tf.data.experimental.AUTOTUNE
train_data = train_data.map(process_data).shuffle(5000).repeat(training_steps).batch(batch_size).prefetch(buffer_size=AUTOTUNE)
目前数据准备完毕!
模型结构
模型的结构如下,现在使用低阶,中阶,高阶 API 来构建这一个模型
低阶 API 构建模型
import matplotlib.pyplot as plt
import tensorflow as tf## 定义模型
class CustomModel(tf.Module):def __init__(self, name=None):super(CustomModel, self).__init__(name=name)self.w1 = tf.Variable(tf.initializers.RandomNormal()([32*32*3, 256]))self.b1 = tf.Variable(tf.initializers.RandomNormal()([256]))self.w2 = tf.Variable(tf.initializers.RandomNormal()([256, 128]))self.b2 = tf.Variable(tf.initializers.RandomNormal()([128]))self.w3 = tf.Variable(tf.initializers.RandomNormal()([128, 64]))self.b3 = tf.Variable(tf.initializers.RandomNormal()([64]))self.w4 = tf.Variable(tf.initializers.RandomNormal()([64, 10]))self.b4 = tf.Variable(tf.initializers.RandomNormal()([10]))def __call__(self, x):x = tf.cast(x, tf.float32)x = tf.reshape(x, [x.shape[0], -1])x = tf.nn.relu(x @ self.w1 + self.b1)x = tf.nn.relu(x @ self.w2 + self.b2)x = tf.nn.relu(x @ self.w3 + self.b3)x = tf.nn.softmax(x @ self.w4 + self.b4)return x
model = CustomModel()## 定义损失
def compute_loss(y, y_pred):y_pred = tf.clip_by_value(y_pred, 1e-9, 1.)loss = tf.keras.losses.sparse_categorical_crossentropy(y, y_pred)return tf.reduce_mean(loss)## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)## 定义准确率
def compute_accuracy(y, y_pred):correct_pred = tf.equal(tf.argmax(y_pred, axis=1), tf.cast(tf.reshape(y, -1), tf.int64))correct_pred = tf.cast(correct_pred, tf.float32)return tf.reduce_mean(correct_pred)## 定义一次epoch
def train_one_epoch(x, y):with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)accuracy = compute_accuracy(y, y_pred)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))return loss.numpy(), accuracy.numpy()## 开始训练loss_list, acc_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):loss, acc = train_one_epoch(batch_x, batch_y)loss_list.append(loss)acc_list.append(acc)if i % 10 == 0:print(f'第{i}次训练->', 'loss:' ,loss, 'acc:', acc)
中阶 API 构建模型
## 定义模型
class CustomModel(tf.Module):def __init__(self):super(CustomModel, self).__init__()self.flatten = tf.keras.layers.Flatten()self.dense_1 = tf.keras.layers.Dense(256, activation='relu')self.dense_2 = tf.keras.layers.Dense(128, activation='relu')self.dense_3 = tf.keras.layers.Dense(64, activation='relu')self.dense_4 = tf.keras.layers.Dense(10, activation='softmax')def __call__(self, x):x = self.flatten(x)x = self.dense_1(x)x = self.dense_2(x)x = self.dense_3(x)x = self.dense_4(x)return xmodel = CustomModel()## 定义损失以及准确率
compute_loss = tf.keras.losses.SparseCategoricalCrossentropy()
train_loss = tf.keras.metrics.Mean()
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)## 定义一次epoch
def train_one_epoch(x, y):with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))train_loss(loss)train_accuracy(y, y_pred)## 开始训练
loss_list, accuracy_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):train_one_epoch(batch_x, batch_y)loss_list.append(train_loss.result())accuracy_list.append(train_accuracy.result())if i % 10 == 0:print(f"第{i}次训练: loss: {train_loss.result()} accuarcy: {train_accuracy.result()}")
高阶 API 构建模型
## 定义模型
model = tf.keras.Sequential([tf.keras.layers.Input(shape=[32,32,3]),tf.keras.layers.Flatten(),tf.keras.layers.Dense(256, activation='relu'),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(10, activation='softmax'),
])## 定义optimizer,loss, accuracy
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),loss = tf.keras.losses.SparseCategoricalCrossentropy(),metrics=['accuracy']
)## 开始训练
model.fit(train_data.take(10000))
保存和导入模型
保存模型
tf.keras.models.save_model(model, 'model_folder')
导入模型
model = tf.keras.models.load_model('model_folder')
相关文章:

Python深度学习基于Tensorflow(3)Tensorflow 构建模型
文章目录 数据导入和数据可视化数据集制作以及预处理模型结构低阶 API 构建模型中阶 API 构建模型高阶 API 构建模型保存和导入模型 这里以实际项目CIFAR-10为例,分别使用低阶,中阶,高阶 API 搭建模型。 这里以CIFAR-10为数据集,C…...

火爆多年的抖音小店,2024年想要入驻需要什么条件呢?
大家好,我是电商糖果 我相信现在只要会上网的年轻人,对抖音小店一定不会感觉陌生。 它最近几年的风头,可是远远超过某宝,某多多了。 不少抖音用户也有了在抖音购物的习惯,现在的抖音上入驻了上百万家电商商家。 这…...

STM32G030C8T6:EEPROM读写实验(I2C通信)
本专栏记录STM32开发各个功能的详细过程,方便自己后续查看,当然也供正在入门STM32单片机的兄弟们参考; 本小节的目标是,系统主频64 MHZ,采用高速外部晶振,实现PB11,PB10 引脚模拟I2C 时序,对M24C08 的EEPRO…...

使用Git管理github的代码库-上
1、下载安装Git https://download.csdn.net/download/notfindjob/11451730?spm1001.2014.3001.5503 2、注册一个github的账号(已经注册的,可略过这一步) 3、打开git命令行,配置github账号 git config --global user.name &quo…...

经典文献阅读之--D-Map(无需射线投射的高分辨率激光雷达传感器的占据栅格地图)
0. 简介 占用地图是机器人系统中推理环境未知和已知区域的基本组成部分。《Occupancy Grid Mapping without Ray-Casting for High-resolution LiDAR Sensors》介绍了一种高分辨率LiDAR传感器的高效占用地图框架,称为D-Map。该框架引入了三个主要创新来解决占用地图…...

开源免费的定时任务管理系统:Gocron
Gocron:精准调度未来,你的全能定时任务管理工具!- 精选真开源,释放新价值。 概览 Gocron是github上一个开源免费的定时任务管理系统。它使用Go语言开发,是一个轻量级定时任务集中调度和管理系统,用于替代L…...

从零开始详解OpenCV车道线检测
前言 车道线检测是智能驾驶和智能交通系统中的重要组成部分,对于提高道路安全、交通效率和驾驶舒适性具有重要意义。在本篇文章中将介绍使用OpenCV进行车道线的检测 详解 导入包 import cv2 import matplotlib.pyplot as plt import numpy as np读入图像并灰度化…...
【Java代码审计】逻辑漏洞篇
【Java代码审计】逻辑漏洞篇 逻辑漏洞概述常见逻辑漏洞点 逻辑漏洞概述 逻辑漏洞一般是由于源程序自身逻辑存在缺陷,导致攻击者可以对逻辑缺陷进行深层次的利用。逻辑漏洞出现较为频繁的地方一般是登录验证逻辑、验证码校验逻辑、密码找回逻辑、权限校验逻辑以及支…...
SSH简介
SSH,全名叫Secure Shell,你可以想象它是一个超级安全的管道,专门用来远程操控电脑的。就好比你在家用遥控器指挥远处的电视换台,但比这高级多了,因为它是专门为电脑设计的。 为什么需要SSH? 在互联网的早期…...
Oracle的高级分组函数grouping和grouping_id
在网上对Oracle的高级分组函数grouping和grouping_id的讲解并不多,特别是grouping_id,还有解说有误的。经过1天研究,已经完全掌握了两个函数的作用和用法,下面简单的讲述即可明白。下面给大家分享。 GROUPING 函数 语法:grouping(表达式) 作用: GROUPING将超聚…...
SqlServer 查询数据库 和 数据表 大小的语句
–Sqlserver 查询数据库 大小 SELECT * FROM (SELECT DB_NAME(database_id) AS DatabaseName,type_desc AS FileType,name AS FileName,size * 8 / 1024/1024 AS FileSizeGBFROM sys.master_filesWHERE type 0 -- 数据文件AND state 0 -- 在线状态 ) T1 ORDER BY FileSizeG…...
特殊类的设计与单例模式
1、特殊类的设计 如何设计出一个创建出的对象只能在堆上的类?将类的默认构造函数设置为私有,再将类的拷贝构造函数设置为delete,设置静态函数GetObj,内部调用new HeapOnly,这样就只能在堆上开辟空间。 class HeapOnly…...

MySQL从入门到高级 --- 6.函数
文章目录 第六章:6.函数6.1 聚合函数6.2 数学函数6.3 字符串函数6.4 日期函数6.4.1 日期格式 6.5 控制流函数6.5.1 if逻辑判断语句6.5.2 case when语句 6.6 窗口函数6.6.1 序号函数6.6.2 开窗聚合函数6.6.3 分布函数6.6.4 前后函数6.6.5 头尾函数6.6.6 其他函数6.7 …...

Qt---信号和槽
一、信号和槽机制 所谓信号槽,实际就是观察者模式。当某个事件发生之后,比如,按钮检测到自己被点击了一下,它就会发出一个信号(signal)。这种发出是没有目的的,类似广播。如果有对象对这个信号…...

POCEXP编写—文件上传案例
POC&EXP编写—文件上传案例 1. 前言2. 文件上传案例2.1. Burp抓包2.2. 基础代码实践2.2.1. 优化代码 2.3. 整体代码2.3.1. 木马测试 1. 前言 之前的文章基本上都是一些相对来说都是验证类的或者说是一些代码执行类的,相对来说都不是太复杂,而这篇会…...

C#知识|上位机UI设计-详情窗体设计思路及流程(实例)
哈喽,你好啊,我是雷工! 上两节练习记录了登录窗体和主窗体的实现过程,本节继续练习内容窗体的实现,以下为练习笔记。 01 详情窗体效果展示: 02 添加窗体并设置属性 在之前练习项目的基础上添加一个Windows窗体,设置名称为:FrmIPManage.cs 设置窗体的边框和标题栏的外…...

目标检测——印度车辆数据集
引言 亲爱的读者们,您是否在寻找某个特定的数据集,用于研究或项目实践?欢迎您在评论区留言,或者通过公众号私信告诉我,您想要的数据集的类型主题。小编会竭尽全力为您寻找,并在找到后第一时间与您分享。 …...
Zotero Word中插入带超链接的参考文献
Zotero 超链接 找了好多原代码,最接近能实施的为: https://blog.csdn.net/weixin_47244593/article/details/129072589 但是,就是向他说的一样会报错,我修改了代码,遇见报错的地方会直接跳过不执行,事后找…...
如何在服务器上下载,解压github上的代码
在github上找到对应仓库,找到平时download zip的地方,右键它,复制链接。在远程的终端里使用wget 链接 命令就可以得到zip了。 解压方法: -c :新建打包文件 -t :查看打包文件的内容含有哪些文件名 -x &…...

BGP学习二:BGP通告原则,BGP反射器,BGP路径属性细致讲解,新手小白无负担
目录 一.AS号 二.BGP路由生成 1.network 2.import-route引入 三.BGP通告原则 1.只发布最优且有效的路由 2.从EBGP获取的路由,会发布给所有对等体 3.水平分割原则 4.IBGP学习BGP默认不发送给EBGP,但如果也从IGP学习到了这条路由,就发…...
在HarmonyOS ArkTS ArkUI-X 5.0及以上版本中,手势开发全攻略:
在 HarmonyOS 应用开发中,手势交互是连接用户与设备的核心纽带。ArkTS 框架提供了丰富的手势处理能力,既支持点击、长按、拖拽等基础单一手势的精细控制,也能通过多种绑定策略解决父子组件的手势竞争问题。本文将结合官方开发文档,…...

HTML 列表、表格、表单
1 列表标签 作用:布局内容排列整齐的区域 列表分类:无序列表、有序列表、定义列表。 例如: 1.1 无序列表 标签:ul 嵌套 li,ul是无序列表,li是列表条目。 注意事项: ul 标签里面只能包裹 li…...

Android15默认授权浮窗权限
我们经常有那种需求,客户需要定制的apk集成在ROM中,并且默认授予其【显示在其他应用的上层】权限,也就是我们常说的浮窗权限,那么我们就可以通过以下方法在wms、ams等系统服务的systemReady()方法中调用即可实现预置应用默认授权浮…...
DeepSeek 技术赋能无人农场协同作业:用 AI 重构农田管理 “神经网”
目录 一、引言二、DeepSeek 技术大揭秘2.1 核心架构解析2.2 关键技术剖析 三、智能农业无人农场协同作业现状3.1 发展现状概述3.2 协同作业模式介绍 四、DeepSeek 的 “农场奇妙游”4.1 数据处理与分析4.2 作物生长监测与预测4.3 病虫害防治4.4 农机协同作业调度 五、实际案例大…...

C# 求圆面积的程序(Program to find area of a circle)
给定半径r,求圆的面积。圆的面积应精确到小数点后5位。 例子: 输入:r 5 输出:78.53982 解释:由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982,因为我们只保留小数点后 5 位数字。 输…...
安卓基础(aar)
重新设置java21的环境,临时设置 $env:JAVA_HOME "D:\Android Studio\jbr" 查看当前环境变量 JAVA_HOME 的值 echo $env:JAVA_HOME 构建ARR文件 ./gradlew :private-lib:assembleRelease 目录是这样的: MyApp/ ├── app/ …...
【JavaSE】多线程基础学习笔记
多线程基础 -线程相关概念 程序(Program) 是为完成特定任务、用某种语言编写的一组指令的集合简单的说:就是我们写的代码 进程 进程是指运行中的程序,比如我们使用QQ,就启动了一个进程,操作系统就会为该进程分配内存…...

HubSpot推出与ChatGPT的深度集成引发兴奋与担忧
上周三,HubSpot宣布已构建与ChatGPT的深度集成,这一消息在HubSpot用户和营销技术观察者中引发了极大的兴奋,但同时也存在一些关于数据安全的担忧。 许多网络声音声称,这对SaaS应用程序和人工智能而言是一场范式转变。 但向任何技…...

AI语音助手的Python实现
引言 语音助手(如小爱同学、Siri)通过语音识别、自然语言处理(NLP)和语音合成技术,为用户提供直观、高效的交互体验。随着人工智能的普及,Python开发者可以利用开源库和AI模型,快速构建自定义语音助手。本文由浅入深,详细介绍如何使用Python开发AI语音助手,涵盖基础功…...

Tauri2学习笔记
教程地址:https://www.bilibili.com/video/BV1Ca411N7mF?spm_id_from333.788.player.switch&vd_source707ec8983cc32e6e065d5496a7f79ee6 官方指引:https://tauri.app/zh-cn/start/ 目前Tauri2的教程视频不多,我按照Tauri1的教程来学习&…...