当前位置: 首页 > news >正文

Python深度学习基于Tensorflow(3)Tensorflow 构建模型

文章目录

        • 数据导入和数据可视化
        • 数据集制作以及预处理
        • 模型结构
        • 低阶 API 构建模型
        • 中阶 API 构建模型
        • 高阶 API 构建模型
        • 保存和导入模型

这里以实际项目CIFAR-10为例,分别使用低阶,中阶,高阶 API 搭建模型。

这里以CIFAR-10为数据集,CIFAR-10为小型数据集,一共包含10个类别的 RGB 彩色图像:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。图像的尺寸为 32×32(像素),3个通道 ,数据集中一共有 50000 张训练圄片和 10000 张测试图像。CIFAR-10数据集有3个版本,这里使用Python版本。

数据导入和数据可视化

这里不用书中给的CIFAR-10数据,直接使用TensorFlow自带的玩意导入数据,可能需要魔法,其实TensorFlow中的数据特别的经典。

![[Pasted image 20240506194103.png]]

接下来导入cifar10数据集并进行可视化展示

import matplotlib.pyplot as plt
import tensorflow as tf(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# x_train.shape, y_train.shape, x_test.shape, y_test.shape
# ((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))index_name = {0:'airplane',1:'automobile',2:'bird',3:'cat',4:'deer',5:'dog',6:'frog',7:'horse',8:'ship',9:'truck'
}def plot_100_img(imgs, labels):fig = plt.figure(figsize=(20,20))for i in range(10):for j in range(10):plt.subplot(10,10,i*10+j+1)plt.imshow(imgs[i*10+j])plt.title(index_name[labels[i*10+j][0]])plt.axis('off')plt.show()plot_100_img(x_test[:100])

![[Pasted image 20240506200312.png]]

数据集制作以及预处理

数据集预处理很简单就能实现,直接一行代码。

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))# 提取出一行数据
# train_data.take(1).get_single_element()

这里接着对数据预处理操作,也很容易就能实现。

def process_data(img, label):img = tf.cast(img, tf.float32) / 255.0return img, labeltrain_data = train_data.map(process_data)# 提取出一行数据
# train_data.take(1).get_single_element()

这里对数据还有一些存储和提取操作

dataset 中 shuffle()、repeat()、batch()、prefetch()等函数的主要功能如下。
1)repeat(count=None) 表示重复此数据集 count 次,实际上,我们看到 repeat 往往是接在 shuffle 后面的。为何要这么做,而不是反过来,先 repeat 再 shuffle 呢? 如果shuffle 在 repeat 之后,epoch 与 epoch 之间的边界就会模糊,出现未遍历完数据,已经计算过的数据又出现的情况。
2)shuffle(buffer_size, seed=None, reshuffle_each_iteration=None) 表示将数据打乱,数值越大,混乱程度越大。为了完全打乱,buffer_size 应等于数据集的数量。
3)batch(batch_size, drop_remainder=False) 表示按照顺序取出 batch_size 大小数据,最后一次输出可能小于batch ,如果程序指定了每次必须输入进批次的大小,那么应将drop_remainder 设置为 True 以防止产生较小的批次,默认为 False。
4)prefetch(buffer_size) 表示使用一个后台线程以及一个buffer来缓存batch,提前为模型的执行程序准备好数据。一般来说,buffer的大小应该至少和每一步训练消耗的batch数量一致,也就是 GPU/TPU 的数量。我们也可以使用AUTOTUNE来设置。创建一个Dataset便可从该数据集中预提取元素,注意:examples.prefetch(2) 表示将预取2个元素(2个示例),而examples.batch(20).prefetch(2) 表示将预取2个元素(2个批次,每个批次有20个示例),buffer_size 表示预提取时将缓冲的最大元素数返回 Dataset。

![[Pasted image 20240506201344.png]]

最后我们对数据进行一些缓存操作

learning_rate = 0.0002
batch_size = 64
training_steps = 40000
display_step = 1000AUTOTUNE = tf.data.experimental.AUTOTUNE
train_data = train_data.map(process_data).shuffle(5000).repeat(training_steps).batch(batch_size).prefetch(buffer_size=AUTOTUNE)

目前数据准备完毕!

模型结构

模型的结构如下,现在使用低阶,中阶,高阶 API 来构建这一个模型

![[Pasted image 20240506202450.png]]

低阶 API 构建模型
import matplotlib.pyplot as plt
import tensorflow as tf## 定义模型
class CustomModel(tf.Module):def __init__(self, name=None):super(CustomModel, self).__init__(name=name)self.w1 = tf.Variable(tf.initializers.RandomNormal()([32*32*3, 256]))self.b1 = tf.Variable(tf.initializers.RandomNormal()([256]))self.w2 = tf.Variable(tf.initializers.RandomNormal()([256, 128]))self.b2 = tf.Variable(tf.initializers.RandomNormal()([128]))self.w3 = tf.Variable(tf.initializers.RandomNormal()([128, 64]))self.b3 = tf.Variable(tf.initializers.RandomNormal()([64]))self.w4 = tf.Variable(tf.initializers.RandomNormal()([64, 10]))self.b4 = tf.Variable(tf.initializers.RandomNormal()([10]))def __call__(self, x):x = tf.cast(x, tf.float32)x = tf.reshape(x, [x.shape[0], -1])x = tf.nn.relu(x @ self.w1 + self.b1)x = tf.nn.relu(x @ self.w2 + self.b2)x = tf.nn.relu(x @ self.w3 + self.b3)x = tf.nn.softmax(x @ self.w4 + self.b4)return x
model = CustomModel()## 定义损失
def compute_loss(y, y_pred):y_pred = tf.clip_by_value(y_pred, 1e-9, 1.)loss = tf.keras.losses.sparse_categorical_crossentropy(y, y_pred)return tf.reduce_mean(loss)## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)## 定义准确率
def compute_accuracy(y, y_pred):correct_pred = tf.equal(tf.argmax(y_pred, axis=1), tf.cast(tf.reshape(y, -1), tf.int64))correct_pred = tf.cast(correct_pred, tf.float32)return tf.reduce_mean(correct_pred)## 定义一次epoch
def train_one_epoch(x, y):with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)accuracy = compute_accuracy(y, y_pred)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))return loss.numpy(), accuracy.numpy()## 开始训练loss_list, acc_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):loss, acc = train_one_epoch(batch_x, batch_y)loss_list.append(loss)acc_list.append(acc)if i % 10 == 0:print(f'第{i}次训练->', 'loss:' ,loss, 'acc:', acc)
中阶 API 构建模型
## 定义模型
class CustomModel(tf.Module):def __init__(self):super(CustomModel, self).__init__()self.flatten = tf.keras.layers.Flatten()self.dense_1 = tf.keras.layers.Dense(256, activation='relu')self.dense_2 = tf.keras.layers.Dense(128, activation='relu')self.dense_3 = tf.keras.layers.Dense(64, activation='relu')self.dense_4 = tf.keras.layers.Dense(10, activation='softmax')def __call__(self, x):x = self.flatten(x)x = self.dense_1(x)x = self.dense_2(x)x = self.dense_3(x)x = self.dense_4(x)return xmodel = CustomModel()## 定义损失以及准确率
compute_loss = tf.keras.losses.SparseCategoricalCrossentropy()
train_loss = tf.keras.metrics.Mean()
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)## 定义一次epoch
def train_one_epoch(x, y):with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))train_loss(loss)train_accuracy(y, y_pred)## 开始训练
loss_list, accuracy_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):train_one_epoch(batch_x, batch_y)loss_list.append(train_loss.result())accuracy_list.append(train_accuracy.result())if i % 10 == 0:print(f"第{i}次训练: loss: {train_loss.result()} accuarcy: {train_accuracy.result()}")
高阶 API 构建模型
## 定义模型
model = tf.keras.Sequential([tf.keras.layers.Input(shape=[32,32,3]),tf.keras.layers.Flatten(),tf.keras.layers.Dense(256, activation='relu'),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(10, activation='softmax'),
])## 定义optimizer,loss, accuracy
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),loss = tf.keras.losses.SparseCategoricalCrossentropy(),metrics=['accuracy']
)## 开始训练
model.fit(train_data.take(10000))
保存和导入模型

保存模型

tf.keras.models.save_model(model, 'model_folder')

导入模型

model = tf.keras.models.load_model('model_folder')

相关文章:

Python深度学习基于Tensorflow(3)Tensorflow 构建模型

文章目录 数据导入和数据可视化数据集制作以及预处理模型结构低阶 API 构建模型中阶 API 构建模型高阶 API 构建模型保存和导入模型 这里以实际项目CIFAR-10为例,分别使用低阶,中阶,高阶 API 搭建模型。 这里以CIFAR-10为数据集,C…...

火爆多年的抖音小店,2024年想要入驻需要什么条件呢?

大家好,我是电商糖果 我相信现在只要会上网的年轻人,对抖音小店一定不会感觉陌生。 它最近几年的风头,可是远远超过某宝,某多多了。 不少抖音用户也有了在抖音购物的习惯,现在的抖音上入驻了上百万家电商商家。 这…...

STM32G030C8T6:EEPROM读写实验(I2C通信)

本专栏记录STM32开发各个功能的详细过程,方便自己后续查看,当然也供正在入门STM32单片机的兄弟们参考; 本小节的目标是,系统主频64 MHZ,采用高速外部晶振,实现PB11,PB10 引脚模拟I2C 时序,对M24C08 的EEPRO…...

使用Git管理github的代码库-上

1、下载安装Git https://download.csdn.net/download/notfindjob/11451730?spm1001.2014.3001.5503 2、注册一个github的账号(已经注册的,可略过这一步) 3、打开git命令行,配置github账号 git config --global user.name &quo…...

经典文献阅读之--D-Map(无需射线投射的高分辨率激光雷达传感器的占据栅格地图)

0. 简介 占用地图是机器人系统中推理环境未知和已知区域的基本组成部分。《Occupancy Grid Mapping without Ray-Casting for High-resolution LiDAR Sensors》介绍了一种高分辨率LiDAR传感器的高效占用地图框架,称为D-Map。该框架引入了三个主要创新来解决占用地图…...

开源免费的定时任务管理系统:Gocron

Gocron:精准调度未来,你的全能定时任务管理工具!- 精选真开源,释放新价值。 概览 Gocron是github上一个开源免费的定时任务管理系统。它使用Go语言开发,是一个轻量级定时任务集中调度和管理系统,用于替代L…...

从零开始详解OpenCV车道线检测

前言 车道线检测是智能驾驶和智能交通系统中的重要组成部分,对于提高道路安全、交通效率和驾驶舒适性具有重要意义。在本篇文章中将介绍使用OpenCV进行车道线的检测 详解 导入包 import cv2 import matplotlib.pyplot as plt import numpy as np读入图像并灰度化…...

【Java代码审计】逻辑漏洞篇

【Java代码审计】逻辑漏洞篇 逻辑漏洞概述常见逻辑漏洞点 逻辑漏洞概述 逻辑漏洞一般是由于源程序自身逻辑存在缺陷,导致攻击者可以对逻辑缺陷进行深层次的利用。逻辑漏洞出现较为频繁的地方一般是登录验证逻辑、验证码校验逻辑、密码找回逻辑、权限校验逻辑以及支…...

SSH简介

SSH,全名叫Secure Shell,你可以想象它是一个超级安全的管道,专门用来远程操控电脑的。就好比你在家用遥控器指挥远处的电视换台,但比这高级多了,因为它是专门为电脑设计的。 为什么需要SSH? 在互联网的早期…...

Oracle的高级分组函数grouping和grouping_id

在网上对Oracle的高级分组函数grouping和grouping_id的讲解并不多,特别是grouping_id,还有解说有误的。经过1天研究,已经完全掌握了两个函数的作用和用法,下面简单的讲述即可明白。下面给大家分享。 GROUPING 函数 语法:grouping(表达式) 作用: GROUPING将超聚…...

SqlServer 查询数据库 和 数据表 大小的语句

–Sqlserver 查询数据库 大小 SELECT * FROM (SELECT DB_NAME(database_id) AS DatabaseName,type_desc AS FileType,name AS FileName,size * 8 / 1024/1024 AS FileSizeGBFROM sys.master_filesWHERE type 0 -- 数据文件AND state 0 -- 在线状态 ) T1 ORDER BY FileSizeG…...

特殊类的设计与单例模式

1、特殊类的设计 如何设计出一个创建出的对象只能在堆上的类?将类的默认构造函数设置为私有,再将类的拷贝构造函数设置为delete,设置静态函数GetObj,内部调用new HeapOnly,这样就只能在堆上开辟空间。 class HeapOnly…...

MySQL从入门到高级 --- 6.函数

文章目录 第六章:6.函数6.1 聚合函数6.2 数学函数6.3 字符串函数6.4 日期函数6.4.1 日期格式 6.5 控制流函数6.5.1 if逻辑判断语句6.5.2 case when语句 6.6 窗口函数6.6.1 序号函数6.6.2 开窗聚合函数6.6.3 分布函数6.6.4 前后函数6.6.5 头尾函数6.6.6 其他函数6.7 …...

Qt---信号和槽

一、信号和槽机制 所谓信号槽,实际就是观察者模式。当某个事件发生之后,比如,按钮检测到自己被点击了一下,它就会发出一个信号(signal)。这种发出是没有目的的,类似广播。如果有对象对这个信号…...

POCEXP编写—文件上传案例

POC&EXP编写—文件上传案例 1. 前言2. 文件上传案例2.1. Burp抓包2.2. 基础代码实践2.2.1. 优化代码 2.3. 整体代码2.3.1. 木马测试 1. 前言 之前的文章基本上都是一些相对来说都是验证类的或者说是一些代码执行类的,相对来说都不是太复杂,而这篇会…...

C#知识|上位机UI设计-详情窗体设计思路及流程(实例)

哈喽,你好啊,我是雷工! 上两节练习记录了登录窗体和主窗体的实现过程,本节继续练习内容窗体的实现,以下为练习笔记。 01 详情窗体效果展示: 02 添加窗体并设置属性 在之前练习项目的基础上添加一个Windows窗体,设置名称为:FrmIPManage.cs 设置窗体的边框和标题栏的外…...

目标检测——印度车辆数据集

引言 亲爱的读者们,您是否在寻找某个特定的数据集,用于研究或项目实践?欢迎您在评论区留言,或者通过公众号私信告诉我,您想要的数据集的类型主题。小编会竭尽全力为您寻找,并在找到后第一时间与您分享。 …...

Zotero Word中插入带超链接的参考文献

Zotero 超链接 找了好多原代码,最接近能实施的为: https://blog.csdn.net/weixin_47244593/article/details/129072589 但是,就是向他说的一样会报错,我修改了代码,遇见报错的地方会直接跳过不执行,事后找…...

如何在服务器上下载,解压github上的代码

在github上找到对应仓库,找到平时download zip的地方,右键它,复制链接。在远程的终端里使用wget 链接 命令就可以得到zip了。 解压方法: -c :新建打包文件 -t :查看打包文件的内容含有哪些文件名 -x &…...

BGP学习二:BGP通告原则,BGP反射器,BGP路径属性细致讲解,新手小白无负担

目录 一.AS号 二.BGP路由生成 1.network 2.import-route引入 三.BGP通告原则 1.只发布最优且有效的路由 2.从EBGP获取的路由,会发布给所有对等体 3.水平分割原则 4.IBGP学习BGP默认不发送给EBGP,但如果也从IGP学习到了这条路由,就发…...

【Python】 -- 趣味代码 - 小恐龙游戏

文章目录 文章目录 00 小恐龙游戏程序设计框架代码结构和功能游戏流程总结01 小恐龙游戏程序设计02 百度网盘地址00 小恐龙游戏程序设计框架 这段代码是一个基于 Pygame 的简易跑酷游戏的完整实现,玩家控制一个角色(龙)躲避障碍物(仙人掌和乌鸦)。以下是代码的详细介绍:…...

超短脉冲激光自聚焦效应

前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应,这是一种非线性光学现象,主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场,对材料产生非线性响应,可能…...

安宝特方案丨XRSOP人员作业标准化管理平台:AR智慧点检验收套件

在选煤厂、化工厂、钢铁厂等过程生产型企业,其生产设备的运行效率和非计划停机对工业制造效益有较大影响。 随着企业自动化和智能化建设的推进,需提前预防假检、错检、漏检,推动智慧生产运维系统数据的流动和现场赋能应用。同时,…...

SCAU期末笔记 - 数据分析与数据挖掘题库解析

这门怎么题库答案不全啊日 来简单学一下子来 一、选择题(可多选) 将原始数据进行集成、变换、维度规约、数值规约是在以下哪个步骤的任务?(C) A. 频繁模式挖掘 B.分类和预测 C.数据预处理 D.数据流挖掘 A. 频繁模式挖掘:专注于发现数据中…...

【JVM】- 内存结构

引言 JVM:Java Virtual Machine 定义:Java虚拟机,Java二进制字节码的运行环境好处: 一次编写,到处运行自动内存管理,垃圾回收的功能数组下标越界检查(会抛异常,不会覆盖到其他代码…...

微信小程序 - 手机震动

一、界面 <button type"primary" bindtap"shortVibrate">短震动</button> <button type"primary" bindtap"longVibrate">长震动</button> 二、js逻辑代码 注&#xff1a;文档 https://developers.weixin.qq…...

04-初识css

一、css样式引入 1.1.内部样式 <div style"width: 100px;"></div>1.2.外部样式 1.2.1.外部样式1 <style>.aa {width: 100px;} </style> <div class"aa"></div>1.2.2.外部样式2 <!-- rel内表面引入的是style样…...

微信小程序云开发平台MySQL的连接方式

注&#xff1a;微信小程序云开发平台指的是腾讯云开发 先给结论&#xff1a;微信小程序云开发平台的MySQL&#xff0c;无法通过获取数据库连接信息的方式进行连接&#xff0c;连接只能通过云开发的SDK连接&#xff0c;具体要参考官方文档&#xff1a; 为什么&#xff1f; 因为…...

SpringTask-03.入门案例

一.入门案例 启动类&#xff1a; package com.sky;import lombok.extern.slf4j.Slf4j; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.cache.annotation.EnableCach…...

Unsafe Fileupload篇补充-木马的详细教程与木马分享(中国蚁剑方式)

在之前的皮卡丘靶场第九期Unsafe Fileupload篇中我们学习了木马的原理并且学了一个简单的木马文件 本期内容是为了更好的为大家解释木马&#xff08;服务器方面的&#xff09;的原理&#xff0c;连接&#xff0c;以及各种木马及连接工具的分享 文件木马&#xff1a;https://w…...