当前位置: 首页 > news >正文

【TensorFlow2 之015】 在 TF 2.0 中实现 AlexNet

一、说明

       在这篇文章中,我们将展示如何在 TensorFlow 2.0 中实现基本的卷积神经网络 \(AlexNet\)。AlexNet 架构由 Alex Krizhevsky 设计,并与 Ilya Sutskever 和 Geoffrey Hinton 一起发布。并获得Image Net2012竞赛中冠军。

教程概述:

  1. 理论回顾
  2. 在 TensorFlow 2.0 中的实现

二 理论回顾

        现实生活中的计算机视觉问题需要大量高质量数据进行训练。过去,人们使用 CIFAR 和 NORB 数据集作为计算机视觉问题的基准数据集。然而,ImageNet竞赛改变了这一点。该数据集需要比以前更复杂的网络才能获得良好的结果。

        AlexNet 是 2012 年取得最佳结果的一种网络架构。它的 Top-5 错误率为 15.3%。第二好的成绩远远落后(26.2%)。

        该架构有大约 6000 万个参数,由以下层组成。

图层类型特征图尺寸内核大小跨步激活
图像1227×227
卷积9655×5511×114ReLU
最大池化9627×273×32
卷积25627×275×51ReLU
最大池化25613×133×32
卷积第384章13×133×31ReLU
卷积第384章13×133×31ReLU
卷积25613×133×31ReLU
最大池化2566×63×32
完全连接4096ReLU
完全连接4096ReLU
完全连接1000软最大

        在我们的例子中,我们将仅在 ImageNet 数据集中的两个类上训练模型,因此我们的最后一个全连接层将只有两个具有 Softmax 激活函数的神经元。

        有一些变化使得 AlexNet 与当时的其他网络不同。让我们看看是什么改变了历史!

2.1  重叠的池化层

        标准池化层汇总同一内核图中相邻神经元组的输出。传统上,相邻池单元总结的邻域不重叠。重叠池化层与标准池化层类似,只是计算 Max 的相邻窗口彼此重叠。

重叠池化与非重叠池化

2.2 ReLU 非线性

        评估神经元输出的传统方法是使用 sigmoid 或 tanh 激活函数。这两个函数固定在最小值和最大值之间,因此它们是饱和非线性的。然而,在 AlexNet 中,使用了修正线性单位函数,或者简称为 \(ReLU\)。该函数的阈值为\(0\)。这是一个非饱和激活函数。

        \(ReLU\) 函数需要更少的计算并允许更快的学习,这对在大型数据集上训练的大型模型的性能有很大影响。

2.3  局部响应标准化

        局部响应归一化 (LRN) 首次在 AlexNet 架构中引入,其中选择的激活函数是 \(ReLU\)。使用 LRN 的原因是为了鼓励 侧向抑制。 这是指神经元减少其邻居活动的能力。当我们使用 ReLU 激活函数处理神经元时,这非常有用。具有 \(ReLU\) 激活函数的神经元具有无界激活,我们需要 LRN 对其进行标准化。

三. TensorFlow 2.0中的实现

        交互式 Colab 笔记本可在以下链接找到

        让我们从导入所有必需的库开始

# Load the TensorBoard notebook extension
%load_ext tensorboard
import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as pltfrom tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout

        导入后,我们需要准备数据。在这里,我们将仅使用 ImageNet 数据集的一小部分。使用以下代码,您可以下载所有图像并将它们存储在文件夹中。

import cv2
import urllib
import requests
import PIL.Image
import numpy as np
from bs4 import BeautifulSoup#ship synset
page = requests.get("http://www.image-net.org/api/text/imagenet.synset.geturls?wnid=n04194289")
soup = BeautifulSoup(page.content, 'html.parser')
#bicycle synset
bikes_page = requests.get("http://www.image-net.org/api/text/imagenet.synset.geturls?wnid=n02834778")
bikes_soup = BeautifulSoup(bikes_page.content, 'html.parser')str_soup=str(soup)
split_urls=str_soup.split('\r\n')bikes_str_soup=str(bikes_soup)
bikes_split_urls=bikes_str_soup.split('\r\n')!mkdir /content/train
!mkdir /content/train/ships
!mkdir /content/train/bikes
!mkdir /content/validation
!mkdir /content/validation/ships
!mkdir /content/validation/bikesimg_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)def url_to_image(url):resp = urllib.request.urlopen(url)image = np.asarray(bytearray(resp.read()), dtype="uint8")image = cv2.imdecode(image, cv2.IMREAD_COLOR)return imagen_of_training_images=100
for progress in range(n_of_training_images):if not split_urls[progress] == None:try:I = url_to_image(split_urls[progress])if (len(I.shape))==3:save_path = '/content/train/ships/img'+str(progress)+'.jpg'cv2.imwrite(save_path,I)except:Nonefor progress in range(n_of_training_images):if not bikes_split_urls[progress] == None:try:I = url_to_image(bikes_split_urls[progress])if (len(I.shape))==3:save_path = '/content/train/bikes/img'+str(progress)+'.jpg'cv2.imwrite(save_path,I)except:Nonefor progress in range(50):if not split_urls[progress] == None:try:I = url_to_image(split_urls[n_of_training_images+progress])if (len(I.shape))==3:save_path = '/content/validation/ships/img'+str(progress)+'.jpg'cv2.imwrite(save_path,I)except:Nonefor progress in range(50):if not bikes_split_urls[progress] == None:try:I = url_to_image(bikes_split_urls[n_of_training_images+progress])if (len(I.shape))==3:save_path = '/content/validation/bikes/img'+str(progress)+'.jpg'cv2.imwrite(save_path,I)except:None

        现在我们可以创建一个网络。原始 AlexNet 的最后一层有 1000 个神经元,但这里我们只使用一个。这是因为我们只将图像用于两个类。为了构建我们的卷积神经网络,我们将使用 Sequential API。

num_classes = 2# AlexNet model
class AlexNet(Sequential):def __init__(self, input_shape, num_classes):super().__init__()self.add(Conv2D(96, kernel_size=(11,11), strides= 4,padding= 'valid', activation= 'relu',input_shape= input_shape,kernel_initializer= 'he_normal'))self.add(MaxPooling2D(pool_size=(3,3), strides= (2,2),padding= 'valid', data_format= None))self.add(Conv2D(256, kernel_size=(5,5), strides= 1,padding= 'same', activation= 'relu',kernel_initializer= 'he_normal'))self.add(MaxPooling2D(pool_size=(3,3), strides= (2,2),padding= 'valid', data_format= None)) self.add(Conv2D(384, kernel_size=(3,3), strides= 1,padding= 'same', activation= 'relu',kernel_initializer= 'he_normal'))self.add(Conv2D(384, kernel_size=(3,3), strides= 1,padding= 'same', activation= 'relu',kernel_initializer= 'he_normal'))self.add(Conv2D(256, kernel_size=(3,3), strides= 1,padding= 'same', activation= 'relu',kernel_initializer= 'he_normal'))self.add(MaxPooling2D(pool_size=(3,3), strides= (2,2),padding= 'valid', data_format= None))self.add(Flatten())self.add(Dense(4096, activation= 'relu'))self.add(Dense(4096, activation= 'relu'))self.add(Dense(1000, activation= 'relu'))self.add(Dense(num_classes, activation= 'softmax'))self.compile(optimizer= tf.keras.optimizers.Adam(0.001),loss='categorical_crossentropy',metrics=['accuracy'])model = AlexNet((227, 227, 3), num_classes)

        创建模型后,我们定义一些重要的参数以供以后使用。此外,让我们创建图像数据生成器。\(AlexNet\)的参数非常多,有6000万个,这是一个巨大的数字。如果没有足够的数据,这将很可能导致过度拟合。因此,在这里,我们将利用数据增强技术,您可以在此处找到更多相关信息。

        出于同样的原因,AlexNet 中使用了 dropout 层。该技术包括以预定概率“关闭”神经元。这迫使每个神经元具有更强大的特征,可以与其他神经元一起使用。我们不会在这里使用 dropout 层,因为我们不会使用整个数据集。

# some training parameters
EPOCHS = 100
BATCH_SIZE = 32
image_height = 227
image_width = 227
train_dir = "train"
valid_dir = "validation"
model_dir = "my_model.h5"

train_datagen = ImageDataGenerator(rescale=1./255,rotation_range=10,width_shift_range=0.1,height_shift_range=0.1,shear_range=0.1,zoom_range=0.1)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(image_height, image_width),color_mode="rgb",batch_size=BATCH_SIZE,seed=1,shuffle=True,class_mode="categorical")valid_datagen = ImageDataGenerator(rescale=1.0/255.0)
valid_generator = valid_datagen.flow_from_directory(valid_dir,target_size=(image_height, image_width),color_mode="rgb",batch_size=BATCH_SIZE,seed=7,shuffle=True,class_mode="categorical")
train_num = train_generator.samples
valid_num = valid_generator.samples

        现在我们可以设置TensorBoard并开始训练我们的模型。这样我们就可以实时跟踪模型性能。

log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
callback_list = [tensorboard_callback]# start training
model.fit(train_generator,epochs=EPOCHS,steps_per_epoch=train_num // BATCH_SIZE,validation_data=valid_generator,validation_steps=valid_num // BATCH_SIZE,callbacks=callback_list,verbose=0)# save the whole model
model.save(model_dir)%tensorboard --logdir logs/fit

        让我们使用我们的模型进行一些预测并将其可视化。

class_names = ['bike', 'ship']x_valid, label_batch  = next(iter(valid_generator))prediction_values = model.predict_classes(x_valid)# set up the figure
fig = plt.figure(figsize=(10, 6))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)# plot the images: each image is 227x227 pixels
for i in range(8):ax = fig.add_subplot(2, 4, i + 1, xticks=[], yticks=[])ax.imshow(x_valid[i,:],cmap=plt.cm.gray_r, interpolation='nearest')if prediction_values[i] == np.argmax(label_batch[i]):# label the image with the blue textax.text(3, 17, class_names[prediction_values[i]], color='blue', fontsize=14)else:# label the image with the red textax.text(3, 17, class_names[prediction_values[i]], color='red', fontsize=14)

 

四、概括

        在这篇文章中,我们展示了如何在 TensorFlow 2.0 中实现 \(AlexNet\)。我们只使用了 ImageNet 数据集的一部分,这就是为什么我们没有得到最好的结果。为了获得更高的准确性,需要更多的数据和更长的训练时间。

参考资料:

 数据黑客变种rs    深度学习 机器学习 TensorFlow    2020 年 2 月 29 日  |  0

相关文章:

【TensorFlow2 之015】 在 TF 2.0 中实现 AlexNet

一、说明 在这篇文章中,我们将展示如何在 TensorFlow 2.0 中实现基本的卷积神经网络 \(AlexNet\)。AlexNet 架构由 Alex Krizhevsky 设计,并与 Ilya Sutskever 和 Geoffrey Hinton 一起发布。并获得Image Net2012竞赛中冠军。 教程概述: 理论…...

Python进阶之迭代器

文章目录 前言一、迭代器介绍及作用1.可迭代对象2. 迭代器 二、常用函数和迭代器1.常用函数2.迭代器 三、总结结束语 💂 个人主页:风间琉璃🤟 版权: 本文由【风间琉璃】原创、在CSDN首发、需要转载请联系博主💬 如果文章对你有帮助、欢迎关注…...

Vue鼠标右键画矩形和Ctrl按键多选组件

效果图 说明 下面会贴出组件代码以及一个Demo&#xff0c;上面的效果图即为Demo的效果&#xff0c;建议直接将两份代码拷贝到自己的开发环境直接运行调试。 组件代码 <template><!-- 鼠标画矩形选择对象 --><div class"objects" ref"objectsR…...

【MySQL JDBC】使用Java连接MySQL数据库

一、什么是JDBC&#xff1f; 理解API的概念 API&#xff1a;Application Programing Interface -- 应用程序编程接口写好一个程序&#xff0c;这个程序需要给别人提供哪些功能&#xff1f;这些功能就是通过一些 函数/类 这样的方式来提供的。例如 Random、Scanner、ArrayList..…...

字节码学习之常见java语句的底层原理

文章目录 前言1. if语句字节码的解析 2. for循环字节码的解析 3. while循环4. switch语句5. try-catch语句6. i 和i的字节码7. try-catch-finally8. 参考文档 前言 上一章我们聊了《JVM字节码指令详解》 。本章我们学以致用&#xff0c;聊一下我们常见的一些java语句的特性底层…...

Godot C#连接信号不能像GDScirpt一样自动添加代码

前言 我网上找了好久&#xff0c;发现Godot 对于C# 的支持还有待增强 使用c#脚本有办法像gds那样连接节点自带信号时自动生成信号吗&#xff1f; 百度贴吧 Godot C# How To, Episode 9. Signals With Parameters | Godot Mono 解决方案 把信号拉长&#xff0c;看他的属性 修…...

快速自动化处理JavaScript渲染页面

在进行网络数据抓取时&#xff0c;许多网站使用了JavaScript来动态加载内容&#xff0c;这给传统的网络爬虫带来了一定的挑战。本文将介绍如何使用Selenium和ChromeDriver来实现自动化处理JavaScript渲染页面&#xff0c;并实现有效的数据抓取。 1、Selenium和ChromeDriver简介…...

通过API接口进行商品价格监控,可以按照以下步骤进行操作

要实现通过API接口进行商品价格监控&#xff0c;可以按照以下步骤进行操作&#xff1a; 申请平台账号并选择API接口&#xff1a;根据需要的功能&#xff0c;选择相应的API接口&#xff0c;例如商品API接口、店铺API接口、订单API接口等&#xff0c;这一步骤通常需要我们在相应…...

(vue3)大事记管理系统 文章管理页

[element-plus进阶] 文章列表渲染&#xff08;带搜索&到分页&#xff09; 表单架设&#xff1a;当前el-form标签配置一个inline属性&#xff0c;里面的元素就会在一行显示了 中英国际化处理&#xff1a;App.vue中el-config-provider标签包裹组件&#xff0c;意味着整个组…...

springboot 使用RocketMQ客户端生产消费消息DEMO

创建springboot项目省略 项目依赖 注意&#xff1a;当前客户端版本是 5.1.3 &#xff0c;安装的rocketmq服务的版本要与其对应 <properties><java.version>11</java.version><rocketmq-client-java-version>5.1.3</rocketmq-client-java-version&…...

第三章 内存管理 四、连续分配管理方式

目录 一、内存空间的分配与回收 1、连续分配管理方式 &#xff08;1&#xff09;、单一连续分配 优点&#xff1a; 缺点&#xff1a; &#xff08;2&#xff09;、固定分区分配 分区大小相等&#xff1a; 分区大小不等&#xff1a; &#xff08;3&#xff09;、动态分区…...

npm install报--4048错误和ERR_SOCKET_TIMEOUT问题解决方法之一

一、问题描述 学习vue数字大屏加载动漫效果时&#xff0c;在项目终端页面输入全局下载指令 npm install -g json-server 问题1、报--4048错误 会报如下错误 operation not permitted......errno: -4048code:EPERMsyscall: mkdir......The operation was reiected by your op…...

合并两个有序数组

给你两个按 非递减顺序 排列的整数数组 nums1 和 nums2&#xff0c;另有两个整数 m 和 n &#xff0c;分别表示 nums1 和 nums2 中的元素数目。 请你 合并 nums2 到 nums1 中&#xff0c;使合并后的数组同样按 非递减顺序 排列。 注意&#xff1a;最终&#xff0c;合并后数组…...

自动泊车系统设计学习笔记

1 概述 1.1 自动泊车系统研究现状 目前对于自动泊车系统的研究方法通常有两种实现方式&#xff1a; 整个泊车操作可以分为四个阶段&#xff1a;第一阶段车辆向前行驶进行车位识别&#xff0c;第二阶段车辆行驶到准备泊车时的待泊车区域&#xff0c;第三阶段车辆按照规划好的…...

基于Java的家电销售网站管理系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域…...

设计模式~备忘录模式(memento)-22

目录  (1)优点&#xff1a; (2)缺点&#xff1a; (3)使用场景&#xff1a; (4)注意事项&#xff1a; (5)应用实例&#xff1a; 代码 备忘录模式(memento) 备忘录模式&#xff08;Memento Pattern&#xff09;保存一个对象的某个状态&#xff0c;以便在适当的时候恢复对…...

【Agora UID 踩坑记录 Java 数据类型】

目录 负数二进制表示Java中32位无符号数的取法项目踩坑记录Java 0xffffffff隐式类型转换的坑 负数二进制表示 由于计算机中数据都以二进制表示&#xff0c;而负数的二级制是根据正数二进制取补码&#xff08;补码就是先取反码&#xff0c;然后加1&#xff09;得到&#xff0c;…...

ESP8285 RTOS SDK OTA

一、官方资源说明 官方指南&#xff1a;空中升级 (OTA) - ESP32 - — ESP-IDF 编程指南 v4.3.6 文档&#xff0c;虽然是正对ESP32的&#xff0c;但是原理是一样的。 官方参考例程&#xff1a;esp-idf\ESP8266_RTOS_SDK\examples\system\ota\&#xff0c;其中包含两个例程&…...

Hadoop3教程(四):HDFS的读写流程及节点距离计算

文章目录 &#xff08;55&#xff09;HDFS 写数据流程&#xff08;56&#xff09; 节点距离计算&#xff08;57&#xff09;机架感知&#xff08;副本存储节点选择&#xff09;&#xff08;58&#xff09;HDFS 读数据流程参考文献 &#xff08;55&#xff09;HDFS 写数据流程 …...

[0xGameCTF 2023] web题解

文章目录 [Week 1]signinbaby_phphello_httprepo_leakping [Week 2]ez_sqli方法一&#xff08;十六进制绕过&#xff09;方法二&#xff08;字符串拼接&#xff09; ez_upload [Week 1] signin 打开题目&#xff0c;查看下js代码 在main.js里找到flag baby_php <?php /…...

【项目实战】通过多模态+LangGraph实现PPT生成助手

PPT自动生成系统 基于LangGraph的PPT自动生成系统&#xff0c;可以将Markdown文档自动转换为PPT演示文稿。 功能特点 Markdown解析&#xff1a;自动解析Markdown文档结构PPT模板分析&#xff1a;分析PPT模板的布局和风格智能布局决策&#xff1a;匹配内容与合适的PPT布局自动…...

力扣-35.搜索插入位置

题目描述 给定一个排序数组和一个目标值&#xff0c;在数组中找到目标值&#xff0c;并返回其索引。如果目标值不存在于数组中&#xff0c;返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 class Solution {public int searchInsert(int[] nums, …...

Selenium常用函数介绍

目录 一&#xff0c;元素定位 1.1 cssSeector 1.2 xpath 二&#xff0c;操作测试对象 三&#xff0c;窗口 3.1 案例 3.2 窗口切换 3.3 窗口大小 3.4 屏幕截图 3.5 关闭窗口 四&#xff0c;弹窗 五&#xff0c;等待 六&#xff0c;导航 七&#xff0c;文件上传 …...

从物理机到云原生:全面解析计算虚拟化技术的演进与应用

前言&#xff1a;我的虚拟化技术探索之旅 我最早接触"虚拟机"的概念是从Java开始的——JVM&#xff08;Java Virtual Machine&#xff09;让"一次编写&#xff0c;到处运行"成为可能。这个软件层面的虚拟化让我着迷&#xff0c;但直到后来接触VMware和Doc…...

算术操作符与类型转换:从基础到精通

目录 前言&#xff1a;从基础到实践——探索运算符与类型转换的奥秘 算术操作符超级详解 算术操作符&#xff1a;、-、*、/、% 赋值操作符&#xff1a;和复合赋值 单⽬操作符&#xff1a;、--、、- 前言&#xff1a;从基础到实践——探索运算符与类型转换的奥秘 在先前的文…...

goreplay

1.github地址 https://github.com/buger/goreplay 2.简单介绍 GoReplay 是一个开源的网络监控工具&#xff0c;可以记录用户的实时流量并将其用于镜像、负载测试、监控和详细分析。 3.出现背景 随着应用程序的增长&#xff0c;测试它所需的工作量也会呈指数级增长。GoRepl…...

高保真组件库:开关

一:制作关状态 拖入一个矩形作为关闭的底色:44 x 22,填充灰色CCCCCC,圆角23,边框宽度0,文本为”关“,右对齐,边距2,2,6,2,文本颜色白色FFFFFF。 拖拽一个椭圆,尺寸18 x 18,边框为0。3. 全选转为动态面板状态1命名为”关“。 二:制作开状态 复制关状态并命名为”开…...

【见合八方平面波导外腔激光器专题系列】用于干涉光纤传感的低噪声平面波导外腔激光器2

----翻译自Mazin Alalus等人的文章 摘要 1550 nm DWDM 平面波导外腔激光器具有低相位/频率噪声、窄线宽和低 RIN 等特点。该腔体包括一个半导体增益芯片和一个带布拉格光栅的平面光波电路波导&#xff0c;采用 14 引脚蝶形封装。这种平面波导外腔激光器设计用于在振动和恶劣的…...

Qt Quick模块功能及架构

Qt 6.0 中的 Qt Quick 模块是构建现代、动态用户界面的核心框架&#xff0c;基于声明式编程&#xff08;QML&#xff09;和 JavaScript&#xff0c;专注于高性能、流畅的动画和跨平台 UI 开发。、 一、主要功能改进 1. Qt Quick 核心架构 QML 引擎升级&#xff1a;Qt 6.0 使用…...

大数据学习(129)-Hive数据分析

&#x1f34b;&#x1f34b;大数据学习&#x1f34b;&#x1f34b; &#x1f525;系列专栏&#xff1a; &#x1f451;哲学语录: 用力所能及&#xff0c;改变世界。 &#x1f496;如果觉得博主的文章还不错的话&#xff0c;请点赞&#x1f44d;收藏⭐️留言&#x1f4dd;支持一…...