当前位置: 首页 > news >正文

pytorch复现3_GoogLenet

背景:
GoogLeNeta是2014年提出的一种全新的深度学习结构,在这之前的AlexNet、VGG等结构都是通过增大网络的深度(层数)来获得更好的训练效果,但层数的增加会带来很多负作用,比如overfit、梯度消失、梯度爆炸等。GoogLeNet通过引入inception从另一种角度来提升训练结果:能更高效的利用计算资源,在相同的计算量下能提取到更多的特征,从而提升训练结果。
因此GoogLeNet在专注于加深网络结构的同时,引入了新的基本结构——Inception模块,以增加网络的宽度。

网络结构图:
在这里插入图片描述
1、Inception模块
Inception就是把多个卷积或池化操作,放在一起组装成一个网络模块
在这里插入图片描述
实际中需要什么样的Inception
  我们在上面提供了一种Inception的结构,但是这个结构存在很多问题,是不能够直接使用的。首要问题就是参数太多,导致特征图厚度太大。为了解决这个问题,作者在其中加入了1X1的卷积核,改进后的Inception结构如下图:

  在这里插入图片描述
代码:
定义BasicConv2d

class BasicConv2d(nn.Module):def __init__(self, in_channels, out_channels, **kwargs):super(BasicConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)x = self.relu(x)return x

定义Inception

class Inception(nn.Module):def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):super(Inception, self).__init__()self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)self.branch2 = nn.Sequential(BasicConv2d(in_channels, ch3x3red, kernel_size=1),BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)   # 保证输出大小等于输入大小)self.branch3 = nn.Sequential(BasicConv2d(in_channels, ch5x5red, kernel_size=1),BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)   # 保证输出大小等于输入大小)self.branch4 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=1, padding=1),BasicConv2d(in_channels, pool_proj, kernel_size=1))

定义分类器

class InceptionAux(nn.Module):def __init__(self, in_channels, num_classes):super(InceptionAux, self).__init__()self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]self.fc1 = nn.Linear(2048, 1024)self.fc2 = nn.Linear(1024, num_classes)def forward(self, x):# aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14x = self.averagePool(x)# aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4x = self.conv(x)# N x 128 x 4 x 4x = torch.flatten(x, 1)x = F.dropout(x, 0.5, training=self.training)# N x 2048x = F.relu(self.fc1(x), inplace=True)x = F.dropout(x, 0.5, training=self.training)# N x 1024x = self.fc2(x)# N x num_classesreturn x

model完整代码:

import torch.nn as nn
import torch
import torch.nn.functional as Fclass GoogLeNet(nn.Module):def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):super(GoogLeNet, self).__init__()self.aux_logits = aux_logitsself.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)self.conv2 = BasicConv2d(64, 64, kernel_size=1)self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)if self.aux_logits:self.aux1 = InceptionAux(512, num_classes)self.aux2 = InceptionAux(528, num_classes)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.dropout = nn.Dropout(0.4)self.fc = nn.Linear(1024, num_classes)if init_weights:self._initialize_weights()def forward(self, x):# N x 3 x 224 x 224x = self.conv1(x)# N x 64 x 112 x 112x = self.maxpool1(x)# N x 64 x 56 x 56x = self.conv2(x)# N x 64 x 56 x 56x = self.conv3(x)# N x 192 x 56 x 56x = self.maxpool2(x)# N x 192 x 28 x 28x = self.inception3a(x)# N x 256 x 28 x 28x = self.inception3b(x)# N x 480 x 28 x 28x = self.maxpool3(x)# N x 480 x 14 x 14x = self.inception4a(x)# N x 512 x 14 x 14if self.training and self.aux_logits:    # eval model lose this layeraux1 = self.aux1(x)x = self.inception4b(x)# N x 512 x 14 x 14x = self.inception4c(x)# N x 512 x 14 x 14x = self.inception4d(x)# N x 528 x 14 x 14if self.training and self.aux_logits:    # eval model lose this layeraux2 = self.aux2(x)x = self.inception4e(x)# N x 832 x 14 x 14x = self.maxpool4(x)# N x 832 x 7 x 7x = self.inception5a(x)# N x 832 x 7 x 7x = self.inception5b(x)# N x 1024 x 7 x 7x = self.avgpool(x)# N x 1024 x 1 x 1x = torch.flatten(x, 1)# N x 1024x = self.dropout(x)x = self.fc(x)# N x 1000 (num_classes)if self.training and self.aux_logits:   # eval model lose this layerreturn x, aux2, aux1return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)class Inception(nn.Module):def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):super(Inception, self).__init__()self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)self.branch2 = nn.Sequential(BasicConv2d(in_channels, ch3x3red, kernel_size=1),BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)   # 保证输出大小等于输入大小)self.branch3 = nn.Sequential(BasicConv2d(in_channels, ch5x5red, kernel_size=1),BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)   # 保证输出大小等于输入大小)self.branch4 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=1, padding=1),BasicConv2d(in_channels, pool_proj, kernel_size=1))def forward(self, x):branch1 = self.branch1(x)branch2 = self.branch2(x)branch3 = self.branch3(x)branch4 = self.branch4(x)outputs = [branch1, branch2, branch3, branch4]return torch.cat(outputs, 1)class InceptionAux(nn.Module):def __init__(self, in_channels, num_classes):super(InceptionAux, self).__init__()self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]self.fc1 = nn.Linear(2048, 1024)self.fc2 = nn.Linear(1024, num_classes)def forward(self, x):# aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14x = self.averagePool(x)# aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4x = self.conv(x)# N x 128 x 4 x 4x = torch.flatten(x, 1)x = F.dropout(x, 0.5, training=self.training)# N x 2048x = F.relu(self.fc1(x), inplace=True)x = F.dropout(x, 0.5, training=self.training)# N x 1024x = self.fc2(x)# N x num_classesreturn xclass BasicConv2d(nn.Module):def __init__(self, in_channels, out_channels, **kwargs):super(BasicConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)x = self.relu(x)return x

相关文章:

pytorch复现3_GoogLenet

背景: GoogLeNeta是2014年提出的一种全新的深度学习结构,在这之前的AlexNet、VGG等结构都是通过增大网络的深度(层数)来获得更好的训练效果,但层数的增加会带来很多负作用,比如overfit、梯度消失、梯度爆炸等。GoogLeNet通过引入i…...

CH09_重新组织数据

拆分变量(Split Variable) 曾用名:移除对参数的赋值(Remove Assignments to Parameters) 曾用名:分解临时变量(Split Temp) let temp 2 * (height width); console.log(temp); t…...

最新 IntelliJ IDEA 旗舰版和社区版下载安装教程(图解)

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…...

优化 FPGA HLS 设计

优化 FPGA HLS 设计 用工具用 C 生成 RTL 的代码基本不可读。以下是如何在不更改任何 RTL 的情况下提高设计性能。 介绍 高级设计能够以简洁的方式捕获设计,从而减少错误并更容易调试。然而,经常出现的问题是性能权衡。在高度复杂的 FPGA 设计中实现高性…...

LVGL库入门 01 - 样式

一、LVGL样式概述 1、创建样式 在 LVGL 中,样式都是以对象的方式存在,一个对象可以描述一种样式。每个控件都可以独立添加样式,创建的样式之间互不影响。 可以使用 lv_style_t 类型创建一个样式并初始化: static lv_style_t s…...

酷克数据出席永洪科技用户大会 携手驱动商业智能升级

10月27日,第7届永洪科技全国用户大会在北京召开。酷克数据作为国内云原生数仓代表企业,受邀出席本次大会,全面展示了云数仓领域最新前沿技术,并进行主题演讲。 携手合作 助力企业释放数据价值 数据仓库是商业智能(BI…...

英语教育目标转变:更加注重实际应用能力培养

今年九月份,北京市教委发布了《关于深入推进高中阶段学校考试招生改革的实施意见》。按照该意见,北京市2024年初三年级学生的初中学业水平考试英语科目听力口语考试与笔试将分离,首次计算机考试将于2023年12月17日进行。 根据《意见》规定,听力口语计算机考试共有两次考试机会…...

Java中的继承和多态

目录 1. 继承 1.1 为什么需要继承 1.2 继承概念 1.3 继承的语法 1.4 父类成员访问 1.4.1 子类中访问父类的成员变量 1.4.2 子类中访问父类的成员方法 1.5 super关键字 1.6 子类构造方法 1.7 super和this 1.8 再谈初始化 1.9 protected 关键字 1.10 继承方式…...

海外问卷调查现在还可以做吗?

可以做,海外问卷调查是一个稳定长期的互联网创业项目。 大家好,我是橙河,这篇文章讲一讲海外问卷调查现在还可以做吗? 海外问卷调查,简单来说,就是外国的商业公司对外发放的付费调查问卷,按照…...

CA证书与服务器证书

服务器证书和CA证书是网络通信中使用的两种重要的证书。服务器证书是用于验证服务器身份的证书,而CA证书是用于验证证书颁发机构(Certificate Authority)身份的证书。 服务器证书是由网站服务器申请并由CA机构颁发的。它包含了服务器的公钥和其他相关信息&#xff…...

AI智能语音识别模块(二)——基于Arduino的语音控制MP3播放器

文章目录 简介离线语音控制模块Mini MP3模块0.96寸 OLED模块实验准备安装库接线定义主要程序实验效果注意事项总结 简介 在前面一篇文章里我们对AI智能语音识别模块进行了介绍,并对离线语音模组下载固件的过程进行了一个简单描述,不知道大家还记不记得&…...

CentOS部署Minikube

基本介绍 Minikube是本地的Kubernetes,专注于使其易于为Kubernete学习和开发。 官方地址:https://minikube.sigs.k8s.io/docs/start/ 部署安装 # CentOS 7.6# 前置条件:安装好Docker或其他容器引擎或虚拟机 参见《CentOS一键部署Docker》…...

第5章_排序与分页

文章目录 1 排序数据1.1 排序规则1.2 单列排序1.3 多列排序排序演示代码 2 分页2.1 背景2.2 实现规则2.3 拓展分页演示代码 3 课后练习 1 排序数据 1.1 排序规则 使用 ORDER BY 子句排序 ASC(ascend): 升序DESC(descend):降序 …...

Elasticsearch实战:常见错误及详细解决方案

Elasticsearch实战:常见错误及详细解决方案 1.read_only_allow_delete":“true” 当我们在向某个索引添加一条数据的时候,可能(极少情况)会碰到下面的报错: {"error": {"root_cause": [{&…...

C#添加缓存,删除缓存,修改缓存

在C#中,可以使用内置的缓存功能或者使用第三方缓存库来管理缓存。下面分别介绍使用内置缓存功能和使用第三方缓存库的方法。 使用内置缓存功能: 添加缓存: 在C#中,可以使用MemoryCache类来添加缓存。以下是一个简单的示例&…...

PADS Router的操作页面及鼠标指令介绍

PADS Router的用户界面由菜单栏,工作界面,一般工具栏,状态栏,项目浏览器组,输出窗口,电子表格组成(图1): 图1 注意:如果你的界面没有显示项目浏览器&#xff…...

Android studio进入手机调试状态

首先usb插入电脑手机打开开发者模式进入点击就会在你的页面显示了...

《Pytorch新手入门》第二节-动手搭建神经网络

《Pytorch新手入门》第二节-动手搭建神经网络 一、神经网络介绍二、使用torch.nn搭建神经网络2.1 定义网络2.2 torch.autograd.Variable2.3 损失函数与反向传播2.4 优化器torch.optim 三、实战-实现图像分类(CIFAR-10数据集)3.1 CIFAR-10数据集加载与预处理3.2 定义网络结构3.3…...

C++ 模板学习笔记

C另外一种编程成为 泛型编程 ,主要利用的技术就是模板 C提供两种模板机制:函数模板和类模板 C11中,函数模板和类模板都可以设定默认参数,传送门 函数模板 一般 typename 和 class 没有区别,typename 还有个作用是使…...

1、Flink基础概念

1、基础知识 (1)、数据流上的有状态计算 (2)、框架和分布式处理引擎,用于对无界和有界数据流进行有状态计算。 (3)、事件驱动型应用,有数据流就进行处理,无数据流就不…...

分享一下怎么做小程序营销活动

小程序营销活动已经成为现代营销的必备利器,它能够帮助企业提高品牌知名度、促进产品销售,以及加强与用户的互动。然而,要想成功地策划和执行一个小程序营销活动,需要精心设计和全面规划。本文将为您介绍小程序营销活动的策划和执…...

Laravel 后台管理 Dcat Admin 使用记录

Laravel Dcat Admin 安装配置修改配置表格操作 Ajax 结合 Pjax 更新数据状态表格 链接表单设置页面(通常修改更新在同一页面)表单 安装配置 安装文档地址 框架版本 Laravel 8.* 修改配置 修改 admin.php 文件 return [// 后台名称name > DAD后台管理,// 标题title > 后台…...

c语言基础:L1-070 吃火锅

以上图片来自微信朋友圈:这种天气你有什么破事打电话给我基本没用。但是如果你说“吃火锅”,那就厉害了,我们的故事就开始了。 本题要求你实现一个程序,自动检查你朋友给你发来的信息里有没有 chi1 huo3 guo1。 输入格式&#x…...

java spring boot 注解、接口和问题解决方法(持续更新)

注解 RestController 是SpringMVC框架中的一个注解,它结合了Controller和ResponseBody两个注解的功能,用于标记一个类或者方法,表示该类或方法用于处理HTTP请求,并将响应的结果直接返回给客户端,而不需要进行视图渲染…...

HMAC_SHA1加密算法和SHA1加密算法的区别

HMAC_SHA1加密算法和SHA1加密算法的区别 应用场景:SHA1目前主要用于政府部门和私营业主处理敏感信息,被视为最先进的加密技术。而HMAC_SHA1主要用于基于密钥的消息认证码(HMAC)运算,需要一个密钥作为输入。密钥需求&a…...

Ubuntu连不上WiFi 或者虽然能连上校园网,但是浏览器打不开登录页面

写在前面 自己的电脑环境: Ubuntu20.04 一、问题描述 自己的 Ubuntu 遇到连接不上 除校园网之外的其他WiFi, 或者 虽然能连上校园网,但是浏览器打不开登录页面的问题。 二、解决方法 出现这种问题的原因可能是 之前开过VPN, 导致系统的网络设置出现…...

Maven第八章:如何解决Maven的jar版本冲突

Maven第八章:如何解决Maven的jar版本冲突 前言 本文重点讲解Maven依赖冲突原因,maven依赖原则以及如何利用idea Maven Helper插件分析解决问题。 背景 开发过程中引入第三方jar遇到依赖冲突的,非常影响开发,甚至大部分时间都在调试版本兼容。 Caused by:java.lang.NoSuch…...

c# 读写内存映射文件

在C#中,可以使用System.IO.MemoryMappedFiles命名空间中的MemoryMappedFile类来操作内存映射文件。可以创建不固定大小的内存映射文件,具体步骤如下: 1. 先创建一个初始大小为0的内存映射文件,使用MemoryMappedFile.CreateNew方法…...

行业揭秘:腾讯共享wifi码推广零加盟费是真的吗?

近年来,“共享经济”概念在商业领域取得了巨大成功。共享WiFi贴码成为共享经济的一种典型案例,被越来越多的人看作是一种低风险、高回报的投资方式。而在这个市场中,腾讯WiFi码推广以“零加盟费”而备受关注。本文将探讨腾讯WiFi码推广零加盟…...

E4980A 精密型 LCR 表,20 Hz 至 2 MHz

E4980A 精密型 LCR 表 20 Hz 至 2 MHz E4980A 精密型 LCR 表实现了测量准确度、速度与通用性的理想结合,适用于各种元器件测量。 E4980A 精密型 LCR 表实现了测量准确度、速度与通用性的理想结合,适用于各种元器件测量。 无论是在低阻抗量程还是在高阻…...