当前位置: 首页 > news >正文

TensorFlow2实战-系列教程14:Resnet实战2

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

Resnet实战1
Resnet实战2
Resnet实战3

4、训练脚本train.py解读------创建模型

def get_model():model = resnet50.ResNet50()if config.model == "resnet34":model = resnet34.ResNet34()if config.model == "resnet101":model = resnet101.ResNet101()if config.model == "resnet152":model = resnet152.ResNet152()model.build(input_shape=(None, config.image_height, config.image_width, config.channels))model.summary()tf.keras.utils.plot_model(model, to_file='model.png')return model# create model
model = get_model()

调用get_model()函数构建模型

get_model()函数:

  1. 通过resnet50.py调用ResNet50类,构建ResNet50模型
  2. 如果在配置参数中设置的是"resnet34"、“resnet101”、“resnet152”,则会对应使用(resnet34.py调用ResNet34类,构建ResNet34模型)、(resnet101.py调用ResNet101类,构建ResNet101模型)、(resnet152.py调用ResNet152类,构建ResNet152模型)
  3. 准备模型以供训练或评估,
  4. 输出模型的概览
  5. 创建了模型的结构图,plot_model 函数从 Keras 工具包中生成模型的可视化表示,指定了保存路径

5、模型构建解析------models/resnet50.py

import tensorflow as tf
from models.residual_block import build_res_block_2
from config import NUM_CLASSESclass ResNet50(tf.keras.Model):def __init__(self, num_classes=NUM_CLASSES):super(ResNet50, self).__init__()self.pre1 = tf.keras.layers.Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='same')self.pre2 = tf.keras.layers.BatchNormalization()self.pre3 = tf.keras.layers.Activation(tf.keras.activations.relu)self.pre4 = tf.keras.layers.MaxPool2D(pool_size=(3, 3), strides=2)self.layer1 = build_res_block_2(filter_num=64, blocks=3)self.layer2 = build_res_block_2(filter_num=128, blocks=4, stride=2)self.layer3 = build_res_block_2(filter_num=256, blocks=6, stride=2)self.layer4 = build_res_block_2(filter_num=512, blocks=3, stride=2)self.avgpool = tf.keras.layers.GlobalAveragePooling2D()self.fc1 = tf.keras.layers.Dense(units=1000, activation=tf.keras.activations.relu)self.drop_out = tf.keras.layers.Dropout(rate=0.5)self.fc2 = tf.keras.layers.Dense(units=num_classes, activation=tf.keras.activations.softmax)def call(self, inputs, training=None, mask=None):pre1 = self.pre1(inputs)pre2 = self.pre2(pre1, training=training)pre3 = self.pre3(pre2)pre4 = self.pre4(pre3)l1 = self.layer1(pre4, training=training)l2 = self.layer2(l1, training=training)l3 = self.layer3(l2, training=training)l4 = self.layer4(l3, training=training)avgpool = self.avgpool(l4)fc1 = self.fc1(avgpool)drop = self.drop_out(fc1)out = self.fc2(drop)return out

class ResNet50(tf.keras.Model),这个类定义了ResNet50模型的结构,以及前向传播的方式、顺序

ResNet50类解析:

  1. 构造函数,传入了预测的类别数
  2. 初始化
  3. pre1 ,定义一个二维卷积,输出64个特征图,7x7的卷积,步长为2
  4. pre2 ,定义一个批归一化
  5. pre3,定义一个ReLU激活函数
  6. pre4,一个二维的最大池化
  7. 依次通过build_res_block_2()函数定义4个残差块
  8. 定义一个全局平均池化
  9. 定义一个全连接层,输出维度为1000
  10. 定义一个dropout
  11. 定义一个输出层的全连接层
  12. 前向传播函数,传入输入值
  13. 依次经过pre1、pre2、pre3、pre4,即卷积、批归一化、ReLU、最大池化
  14. 依次经过layer1 、layer2 、layer3 、layer4 等四个残差块
  15. 将layer4 的输出经过平局池化
  16. 依次经过两个全连接层

6、模型构建解析------models/residual_block.py

  • BottleNeck类
  • build_res_block_2()函数
  • build_res_block_2()函数通过调用BottleNeck类构建残差块
class BottleNeck(tf.keras.layers.Layer):def __init__(self, filter_num, stride=1,with_downsample=True):super(BottleNeck, self).__init__()self.with_downsample = with_downsampleself.conv1 = tf.keras.layers.Conv2D(filters=filter_num, kernel_size=(1, 1), strides=1, padding='same')self.bn1 = tf.keras.layers.BatchNormalization()self.conv2 = tf.keras.layers.Conv2D(filters=filter_num, kernel_size=(3, 3), strides=stride, padding='same')self.bn2 = tf.keras.layers.BatchNormalization()self.conv3 = tf.keras.layers.Conv2D(filters=filter_num * 4, kernel_size=(1, 1), strides=1, padding='same')self.bn3 = tf.keras.layers.BatchNormalization()self.downsample = tf.keras.Sequential()self.downsample.add(tf.keras.layers.Conv2D(filters=filter_num * 4, kernel_size=(1, 1), strides=stride))self.downsample.add(tf.keras.layers.BatchNormalization())def call(self, inputs, training=None):identity = self.downsample(inputs)conv1 = self.conv1(inputs)bn1 = self.bn1(conv1, training=training)relu1 = tf.nn.relu(bn1)conv2 = self.conv2(relu1)bn2 = self.bn2(conv2, training=training)relu2 = tf.nn.relu(bn2)conv3 = self.conv3(relu2)bn3 = self.bn3(conv3, training=training)if self.with_downsample == True:output = tf.nn.relu(tf.keras.layers.add([identity, bn3]))else:output = tf.nn.relu(tf.keras.layers.add([inputs, bn3]))return output

BottleNeck类解析:

  1. 继承tf.keras.layers.Layer
  2. 构造函数,传入 特征图个数、步长、是否下采样等参数
  3. 初始化
  4. 是否进行下采样参数
  5. 定义一个1x1,步长为1的二维卷积conv1
  6. conv1 对应的批归一化
  7. 定义一个3x3,步长为1的二维卷积conv2
  8. conv2 对应的批归一化
  9. 定义一个3x3,步长为1的二维卷积conv2
  10. conv3 对应的批归一化
  11. 定义一个下采样层(self.downsample),这个层是一个包含卷积层和批量归一化的 Sequential 模型,用于匹配输入和残差的维度
  12. call()函数为前向传播
  13. 应用下采样
  14. 应用三层卷积和批量归一化以及对应的ReLU
  15. with_downsample == True:
  16. 启用下采样,将下采样后的输入(identity)与最后一个卷积层的输出(bn3)相加
  17. 没有启用下采样,将原始输入(inputs)与最后一个卷积层的输出(bn3)相加
def build_res_block_2(filter_num, blocks, stride=1):res_block = tf.keras.Sequential()res_block.add(BottleNeck(filter_num, stride=stride))for _ in range(1, blocks):res_block.add(BottleNeck(filter_num, stride=1,with_downsample=False))    return res_block

build_res_block_2函数解析:

  1. 这个函数构建了一个包含多个BottleNeck层的残差块
  2. filter_num 是每个瓶颈层内卷积层的过滤器数量
  3. blocks 是要添加到顺序模型中的瓶颈层的数量
  4. stride 是卷积的步长,默认为 1
  5. 该函数初始化一个 Sequential 模型,并添加一个 BottleNeck 层作为第一层
  6. 然后,它迭代地添加额外的 BottleNeck 层,每个层的 stride=1 且
    with_downsample=False(除第一个之外)
  7. 此函数返回组装好的顺序模型,代表一个残差块

Resnet实战1
Resnet实战2
Resnet实战3

相关文章:

TensorFlow2实战-系列教程14:Resnet实战2

🧡💛💚TensorFlow2实战-系列教程 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Jupyter Notebook中进行 本篇文章配套的代码资源已经上传 Resnet实战1 Resnet实战2 Resnet实战3 4、训练脚本train.py解读------创建模型 def …...

编程笔记 html5cssjs 069 JavaScript Undefined数据类型

编程笔记 html5&css&js 069 JavaScript Undefined数据类型 一、undefined数据类型二、类型运算小结 在JavaScript中,undefined 是一种基本数据类型,它表示一个变量已经声明但未定义(即没有赋值)或者一个对象属性不存在。 …...

《区块链简易速速上手小册》第6章:区块链在金融服务领域的应用(2024 最新版)

文章目录 6.1 金融服务中的区块链6.1.1 金融服务中区块链的基础6.1.2 主要案例:跨境支付6.1.3 拓展案例 1:去中心化金融(DeFi)6.1.4 拓展案例 2:代币化资产 6.2 区块链在支付系统中的作用6.2.1 支付系统中区块链的基础…...

【消息队列】kafka整理

kafka整理 整理kafka基本知识供回顾。...

python--杂识--16--代理密码中包含特殊字符

1 安装nginx 2 centos环境安装 yum install httpd-tools3 nginx.conf /etc/nginx/conf/nginx.conf #user nobody; worker_processes 1;#error_log logs/error.log; #error_log logs/error.log notice; #error_log logs/error.log info;#pid logs/nginx.pid;e…...

【Git】05 分离头指针

文章目录 一、分离头指针二、创建分支三、比较commit内容四、总结 一、分离头指针 正常情况下,在通过git checkout命令切换分支时,在命令后面跟着的是分支名(例如master、temp等)或分支名对应commit的哈希值。 非正常情况下&…...

【Tomcat与网络9】提高Tomcat启动速度的八大措施

本文我们来看一下如何对Tomcat进行调优,我们对于Tomcat的调优主要集中在三个方面:提高启动速度、提高系统稳定性和提高并发能力,后两者很多时候是相辅相成的,我们放在一起看。 Tomcat现在一般都嵌入在SpringBoot里,因…...

蓝桥杯嵌入式第七届真题(完成) STM32G431

蓝桥杯嵌入式第七届真题(完成) STM32G431 题目 相关文件 main.c /* USER CODE BEGIN Header */ /********************************************************************************* file : main.c* brief : Main program body**********************…...

如何降低视频RTSP解码延迟

降低RTSP(Real-Time Streaming Protocol)视频流的解码延迟涉及到网络传输和解码处理的优化。以下是一些常见的方法: 选择低延迟的解码器:使用专为低延迟优化的解码器,例如一些定制的H.264或H.265解码器。 优化解码器设…...

【Golang】自定义logrus日志保存为日志文件

背景 为了方便查看日志,项目中需要把日志保存到对应的日志文件中,所以需要当前的配置,以使得日志能够保存到对应的日志文件中。 代码 import ("github.com/orandin/lumberjackrus""github.com/sirupsen/logrus" )func …...

【大厂AI课学习笔记】1.4 算法的进步(4)关于李飞飞团队的ImageNet

第一个图像数据库是ImageNet,由斯坦福大学的计算机科学家李飞飞推出。ImageNet是一个大型的可视化数据库,旨在推动计算机视觉领域的研究。这个数据库包含了数以百万计的手工标记的图像,涵盖了数千个不同的类别。 基于ImageNet数据库&#xf…...

【Linux笔记】缓冲区的概念到标准库的模拟实现

一、缓冲区 “缓冲区”这个概念相信大家或多或少都听说过,大家其实在C语言阶段就已经接触到“缓冲区”这个东西,但是相信大家在C语言阶段并没有真正弄懂缓冲区到底是个什么东西,也相信大家在C语言阶段也因为缓冲区的问题写出过各种bug。 其…...

【前端收藏】前端小作文-前端八股文知识总结(超万字超详细)持续更新

有了这个八股文不仅对你基础知识的巩固,不管你是几年老前端程序员,还是要去面试的,文章覆盖了前端常用及不常用的方方面面,都是前端日后能用上的,对你的前端知识有总结意义,看完后,懂的不懂的都…...

GNSS模块的惯导技术:引领定位科技的前沿

全球导航卫星系统(GNSS)模块的惯导技术是一项颇具前瞻性的科技,它结合了全球定位系统和惯性导航技术,为各个领域的定位需求提供了更为精准和可靠的解决方案。本文将深入探讨GNSS模块的惯导技术,以及它如何在多个领域中…...

Flutter 和 Android原生(Activity、Fragment)相互跳转、传参

前言 本文主要讲解 Flutter 和 Android原生之间,页面相互跳转、传参, 但其中用到了两端相互通信的知识,非常建议先看完这篇 讲解通信的文章: Flutter 与 Android原生 相互通信:BasicMessageChannel、MethodChannel、…...

Kubernetes基础(十一)-CNI网络插件用法和对比

1 CNI概述 1.1 什么是CNI? Kubernetes 本身并没有实现自己的容器网络,而是借助 CNI 标准,通过插件化的方式来集成各种网络插件,实现集群内部网络相互通信。 CNI(Container Network Interface,容器网络的…...

yo!这里是单例模式相关介绍

目录 前言 特殊类设计 只能在堆上创建对象的类 1.方法一(构造函数下手) 2.方法二(析构函数下手) 只能在栈上创建对象的类 单例模式 饿汉模式实现 懒汉模式实现 后记 前言 在面向找工作学习c的过程中,除了基本…...

2023年上-未来几年我要做什么

1月份,离职。 2月份,春节休假回来,中旬去参加了一个月的瑜伽培训,学会了倒立、鹤蝉。。。。 3月份,瑜伽培训结束,开始收拾房子,并调研各类项目。 4月份,参与了朋友的区块链项目 …...

智能汽车竞赛摄像头处理(3)——动态阈值二值化(大津法)

前言 (1)在上一节中,我们学习了对图像的固定二值化处理,可以将原始图像处理成二值化的黑白图像,这里面的本质就是将原来的二维数组进行了处理,处理后的二维数组里的元素都是0和255两个值。 (2…...

BGP协议

1.BGP相关概念 1.1 BGP的起源 不同自治系统(路由域)间路由交换与管理的需求推动了EGP的发展,但是EGP的算法简单,无法选路,从而被BGP取代。 自治系统:(AS) IGP:自治系统…...

网络编程(Modbus进阶)

思维导图 Modbus RTU(先学一点理论) 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议,由 Modicon 公司(现施耐德电气)于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…...

docker详细操作--未完待续

docker介绍 docker官网: Docker:加速容器应用程序开发 harbor官网:Harbor - Harbor 中文 使用docker加速器: Docker镜像极速下载服务 - 毫秒镜像 是什么 Docker 是一种开源的容器化平台,用于将应用程序及其依赖项(如库、运行时环…...

python/java环境配置

环境变量放一起 python: 1.首先下载Python Python下载地址:Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个,然后自定义,全选 可以把前4个选上 3.环境配置 1)搜高级系统设置 2…...

Qt Widget类解析与代码注释

#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码,写上注释 当然可以!这段代码是 Qt …...

【JVM】- 内存结构

引言 JVM:Java Virtual Machine 定义:Java虚拟机,Java二进制字节码的运行环境好处: 一次编写,到处运行自动内存管理,垃圾回收的功能数组下标越界检查(会抛异常,不会覆盖到其他代码…...

Go 语言接口详解

Go 语言接口详解 核心概念 接口定义 在 Go 语言中,接口是一种抽象类型,它定义了一组方法的集合: // 定义接口 type Shape interface {Area() float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的: // 矩形结构体…...

大语言模型如何处理长文本?常用文本分割技术详解

为什么需要文本分割? 引言:为什么需要文本分割?一、基础文本分割方法1. 按段落分割(Paragraph Splitting)2. 按句子分割(Sentence Splitting)二、高级文本分割策略3. 重叠分割(Sliding Window)4. 递归分割(Recursive Splitting)三、生产级工具推荐5. 使用LangChain的…...

最新SpringBoot+SpringCloud+Nacos微服务框架分享

文章目录 前言一、服务规划二、架构核心1.cloud的pom2.gateway的异常handler3.gateway的filter4、admin的pom5、admin的登录核心 三、code-helper分享总结 前言 最近有个活蛮赶的,根据Excel列的需求预估的工时直接打骨折,不要问我为什么,主要…...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”

2025年#高考 将在近日拉开帷幕,#AI 监考一度冲上热搜。当AI深度融入高考,#时间同步 不再是辅助功能,而是决定AI监考系统成败的“生命线”。 AI亮相2025高考,40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕,江西、…...

使用Spring AI和MCP协议构建图片搜索服务

目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式(本地调用) SSE模式(远程调用) 4. 注册工具提…...