当前位置: 首页 > news >正文

pytorch学习(十二):对现有的模型进行修改

以VGG16为例:

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

特征提取部分(features

  • 卷积层与ReLU激活:网络的前半部分主要由卷积层(Conv2d)和ReLU激活函数(ReLU)交替组成。每个卷积层后都紧跟一个ReLU层,用于引入非线性。这种结构有助于网络学习复杂的特征表示。

  • 卷积层配置

    • 初始阶段,使用64个3x3的卷积核,然后是ReLU激活,接着是另一个3x3卷积核和ReLU激活,之后是一个2x2的最大池化层(MaxPool2d),用于降低特征图的尺寸并增加感受野。
    • 类似地,这个过程在特征图的通道数增加到128、256和512时重复,每次增加通道数后都会跟随几个卷积层和ReLU激活,然后是一个最大池化层。
    • 值得注意的是,在512通道的部分,卷积层和ReLU激活的组合被重复了三次,而没有立即进行池化,这可能是为了进一步增强特征表示。
  • 最大池化层:用于在每个阶段的末尾减少特征图的尺寸,这有助于减少计算量和参数数量,同时保持重要的特征信息。

全连接层部分(classifier

  • 自适应平均池化:在特征提取部分之后,使用了一个自适应平均池化层(AdaptiveAvgPool2d),将特征图的尺寸调整为7x7。这是为了确保无论输入图像的大小如何,全连接层都能接收到固定大小的输入。

  • 全连接层

    • 第一个全连接层(Linear)将7x7x512的特征图展平为25088个特征,并映射到4096个输出特征上。
    • 接着是两个ReLU激活层、两个Dropout层(用于防止过拟合)和另外两个全连接层,最终输出1000个类别的得分(假设是用于ImageNet分类任务)。

可以看到,最后一层的全连接层的输出是1000,那么当我们有例如十分类的问题时候,就需要对网络进行修改。

 

vgg16_true.add_module('add_linear',nn.Linear(1000,10))

运行上述代码,在末尾加一层线性层,也就是全连接层。

还有一种方式是对原有的全连接层进行修改,将1000改为10。

vgg16_false.classifier[6]=nn.Linear(4096,10)

附上所有源代码;

# -*- coding: utf-8 -*-  
# File created on 2024/8/9 
# 作者:酷尔
# 公众号:酷尔计算机import torchvision
from torch import nn
# train_data=torchvision.datasets.ImageNet('./data_imagenet',split='train',download=True,transform=torchvision.transforms.ToTensor())vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_true=torchvision.models.vgg16(pretrained=True)# print(vgg16_true)
# import os
#
# # 尝试从环境变量中获取TORCH_HOME
# torch_home = os.getenv('TORCH_HOME', os.path.expanduser('~/.torch'))
# model_cache_dir = os.path.join(torch_home, 'models')
#
# print(f"Model cache directory: {model_cache_dir}")
# 
# # 注意:这个目录可能不直接包含模型文件,因为 PyTorch 可能使用了内部的缓存机制
# # 来管理这些文件,并且它们可能以哈希名存储而不是直接以模型名存储。train_data=torchvision.datasets.CIFAR10('./dataset',train=True,download=True,transform=torchvision.transforms.ToTensor())
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
# print(vgg16_true)vgg16_false.classifier[6]=nn.Linear(4096,10)
print(vgg16_false)

相关文章:

pytorch学习(十二):对现有的模型进行修改

以VGG16为例: VGG((features): Sequential((0): Conv2d(3, 64, kernel_size(3, 3), stride(1, 1), padding(1, 1))(1): ReLU(inplaceTrue)(2): Conv2d(64, 64, kernel_size(3, 3), stride(1, 1), padding(1, 1))(3): ReLU(inplaceTrue)(4): MaxPool2d(kernel_size2…...

服务器虚拟内存是什么?虚拟内存怎么设置?

服务器虚拟内存是计算机系统内存管理的一种重要技术,它允许应用程序认为它们拥有连续且完整的内存地址空间,而实际上这些内存空间是由多个物理内存碎片和外部磁盘存储器上的空间共同组成的。当物理内存(RAM)不足时,系统…...

深度学习入门指南(1) - 从chatgpt入手

2012年,加拿大多伦多大学的Hinton教授带领他的两个学生Alex和Ilya一起用AlexNet撞开了深度学习的大门,从此人类走入了深度学习时代。 2015年,这个第二作者80后Ilya Sutskever参与创建了openai公司。现在Ilya是openai的首席科学家,…...

Python学习笔记(六)

""" 演示对序列进行切片操作 """ # 切片;从一个序列中,取出一个子序列 # 语法[起始下标:结束下标:步长] # 这三个都不写也行,视为从头到尾步长为1 # 起始下标不写,视作从头开…...

大数据安全规划总体方案(45页PPT)

方案介绍: 大数据安全规划总体方案的制定,旨在应对当前大数据环境中存在的各类安全风险,包括但不限于数据泄露、数据篡改、非法访问等。通过构建完善的安全防护体系,保障大数据在采集、存储、处理、传输、共享等全生命周期中的安…...

第20周:Pytorch文本分类入门

目录 前言 一、前期准备 1.1 环境安装导入包 1.2 加载数据 1.3 构建词典 1.4 生成数据批次和迭代器 二、准备模型 2.1 定义模型 2.2 定义示例 2.3 定义训练函数与评估函数 三、训练模型 3.1 拆分数据集并运行模型 3.2 使用测试数据集评估模型 总结 前言 &#x1…...

记一次 SpringBoot2.x 配置 Fastjson请求报 internal server 500

1.遇到的问题 报错springboot从2.1.16升级到2.5.15,之后就报500内部错误,后面调用都是正常的,就考虑转换有错。 接口返回错误: 2.解决办法 因为我用了fastjson,需要转换下,目前可能理解就是springboot-we…...

OSPF笔记

OSPF:开放式最短路径优先协议 使用范围:IGP 协议算法特点:链路状态型路由协议,SPF算法 协议是否传递网络掩码:传递网络掩码 协议封装:基于ip协议封装,协议号为89 一,ospf特点 1…...

IOC容器初始化流程

IOC容器初始化流程 一、概要1.准备上下文prepareRefresh()2. 获取beanFactory:obtainFreshBeanFactory()3. 准备beanFactory:prepareBeanFactory(beanFactory)4. 后置处理:postProcessBeanFactory()5. 调用bean工厂后置处理器:invokeBeanFactoryPostProcessors()6. 注册bea…...

第二季度云计算市场份额榜单:微软下滑,谷歌上升,AWS仍保持领先

2024 年第二季度,随着企业云支出达到 790 亿美元的新高,三大云计算巨头微软、谷歌云和 AWS的全球云市场份额发生了变化。 根据新的市场数据,以下是 2024 年第二季度全球云市场份额结果和六大世界领先者,其中包括 AWS、阿里巴巴、…...

三点确定圆心算法推导

已知a,b,c三点求过这三点的圆心坐标 a ( x 1 , y 1 ) a(x_1, y_1) a(x1​,y1​) 、 b ( x 2 , y 2 ) b(x_2, y_2) b(x2​,y2​) 、 c ( x 3 , y 3 ) c(x_3, y_3) c(x3​,y3​) 确认三点是否共线 叉积计算方式 v → ( X 1 , Y 1 ) u → ( X 2 , Y 2 ) X 1 Y 2 − X 2 Y 1 \…...

神经网络 (NN) TensorFlow Playground在线应用程序

神经网络 (NN) 历史上最重要的发现之一是神经网络 (NN) 的强大功能。 在神经网络中,称为神经元的许多数据层被添加在一起或相互堆叠以计算新的数据级别。 常用的简称: DNN 深度神经网络CNN 卷积神经网络RNN 循环神经网络 神经元 科学家一致认为&am…...

腾讯课堂 离线m3u8.sqlite转成视频

为了广大腾讯课堂用户对于购买的课程不能正常离线播放,构成知识付费损失,故出此文档。 重点:完全免费!!!完全免费!!!完全免费!!! 怎么…...

Linux多路转接

文章目录 IO模型多路转接select 和 pollepoll IO模型 在还在学习语言的阶段,C里使用cin,或者是C使用scanf的时候,总是要等着我们输入数据才执行,这种IO是阻塞IO。下面是比较正式的说法。 阻塞IO: 在内核将数据准备好之前&#xf…...

IDEA导入Maven项目的流程配置以常见问题解决

1. 前言 本文主要围绕着在IDEA中导入新Maven项目后的配置及常见问题解决来展开说说。相关的部分软件如下: IntelliJ IDEA 2021.1JDK 1.8Window 2. 导入Maven项目及配置 2.1 导入Maven项目 下面介绍了直接打开本地项目和导入git上的项目两种导入Maven方式。 1…...

【数据分析---- Pandas进阶指南:核心计算方法、缺失值处理及数据类型管理】

前言: 💞💞大家好,我是书生♡,本阶段和大家一起分享和探索数据分析,本篇文章主要讲述了:Pandas进阶指南:核心计算方法、缺失值处理及数据类型管理等等。欢迎大家一起探索讨论&#x…...

2024世界机器人大会将于8月21日至25日在京举行

2024年的世界机器人大会预定于8月21日至25日,在北京经济技术开发区的北人亦创国际会展中心隆重举办。 本届大会以“共育新质生产力 共享智能新未来”为核心主题,将汇聚来自全球超过300位的机器人行业专家、国际组织代表、杰出科学家以及企业家&#xff0…...

【Linux】lvm被删除或者lvm丢失了怎么办

模拟案例 接下来模拟lvm误删除如何恢复的案例: 模拟删除: 查看vg名: vgdisplayvgcfgrestore --list uniontechos #查看之前的操作 例如我删除的,现场没有删除就用最近的操作文件: 还原: vgcfgrestore…...

疫情防控管理系统

摘 要 由于当前疫情防控形势复杂,为做好学校疫情防控管理措施,根据上级防疫部门要求,为了学生的生命安全,要求学校加强疫情防控的管理。为了迎合时代需求,优化管理效率,各种各样的管理系统应运而生&#x…...

永久删除的Android 文件去哪了?在Android上恢复误删除的消息和照片方法?

丢失重要消息和照片可能是一种令人沮丧的经历,尤其是在您没有备份的情况下。但别担心,在本教程中,我们将指导您完成在Android设备上恢复已删除消息和照片的步骤。无论您是不小心删除了它们还是由于软件问题而消失了,这些步骤都可以…...

PHP和Node.js哪个更爽?

先说结论,rust完胜。 php:laravel,swoole,webman,最开始在苏宁的时候写了几年php,当时觉得php真的是世界上最好的语言,因为当初活在舒适圈里,不愿意跳出来,就好比当初活在…...

python/java环境配置

环境变量放一起 python: 1.首先下载Python Python下载地址:Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个,然后自定义,全选 可以把前4个选上 3.环境配置 1)搜高级系统设置 2…...

关于nvm与node.js

1 安装nvm 安装过程中手动修改 nvm的安装路径, 以及修改 通过nvm安装node后正在使用的node的存放目录【这句话可能难以理解,但接着往下看你就了然了】 2 修改nvm中settings.txt文件配置 nvm安装成功后,通常在该文件中会出现以下配置&…...

pam_env.so模块配置解析

在PAM(Pluggable Authentication Modules)配置中, /etc/pam.d/su 文件相关配置含义如下: 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块,负责验证用户身份&am…...

JVM垃圾回收机制全解析

Java虚拟机(JVM)中的垃圾收集器(Garbage Collector,简称GC)是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象,从而释放内存空间,避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...

【ROS】Nav2源码之nav2_behavior_tree-行为树节点列表

1、行为树节点分类 在 Nav2(Navigation2)的行为树框架中,行为树节点插件按照功能分为 Action(动作节点)、Condition(条件节点)、Control(控制节点) 和 Decorator(装饰节点) 四类。 1.1 动作节点 Action 执行具体的机器人操作或任务,直接与硬件、传感器或外部系统…...

浅谈不同二分算法的查找情况

二分算法原理比较简单,但是实际的算法模板却有很多,这一切都源于二分查找问题中的复杂情况和二分算法的边界处理,以下是博主对一些二分算法查找的情况分析。 需要说明的是,以下二分算法都是基于有序序列为升序有序的情况&#xf…...

docker 部署发现spring.profiles.active 问题

报错: org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

rnn判断string中第一次出现a的下标

# coding:utf8 import torch import torch.nn as nn import numpy as np import random import json""" 基于pytorch的网络编写 实现一个RNN网络完成多分类任务 判断字符 a 第一次出现在字符串中的位置 """class TorchModel(nn.Module):def __in…...

安宝特方案丨船舶智造的“AR+AI+作业标准化管理解决方案”(装配)

船舶制造装配管理现状:装配工作依赖人工经验,装配工人凭借长期实践积累的操作技巧完成零部件组装。企业通常制定了装配作业指导书,但在实际执行中,工人对指导书的理解和遵循程度参差不齐。 船舶装配过程中的挑战与需求 挑战 (1…...