当前位置: 首页 > news >正文

pytorch学习(十二):对现有的模型进行修改

以VGG16为例:

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

特征提取部分(features

  • 卷积层与ReLU激活:网络的前半部分主要由卷积层(Conv2d)和ReLU激活函数(ReLU)交替组成。每个卷积层后都紧跟一个ReLU层,用于引入非线性。这种结构有助于网络学习复杂的特征表示。

  • 卷积层配置

    • 初始阶段,使用64个3x3的卷积核,然后是ReLU激活,接着是另一个3x3卷积核和ReLU激活,之后是一个2x2的最大池化层(MaxPool2d),用于降低特征图的尺寸并增加感受野。
    • 类似地,这个过程在特征图的通道数增加到128、256和512时重复,每次增加通道数后都会跟随几个卷积层和ReLU激活,然后是一个最大池化层。
    • 值得注意的是,在512通道的部分,卷积层和ReLU激活的组合被重复了三次,而没有立即进行池化,这可能是为了进一步增强特征表示。
  • 最大池化层:用于在每个阶段的末尾减少特征图的尺寸,这有助于减少计算量和参数数量,同时保持重要的特征信息。

全连接层部分(classifier

  • 自适应平均池化:在特征提取部分之后,使用了一个自适应平均池化层(AdaptiveAvgPool2d),将特征图的尺寸调整为7x7。这是为了确保无论输入图像的大小如何,全连接层都能接收到固定大小的输入。

  • 全连接层

    • 第一个全连接层(Linear)将7x7x512的特征图展平为25088个特征,并映射到4096个输出特征上。
    • 接着是两个ReLU激活层、两个Dropout层(用于防止过拟合)和另外两个全连接层,最终输出1000个类别的得分(假设是用于ImageNet分类任务)。

可以看到,最后一层的全连接层的输出是1000,那么当我们有例如十分类的问题时候,就需要对网络进行修改。

 

vgg16_true.add_module('add_linear',nn.Linear(1000,10))

运行上述代码,在末尾加一层线性层,也就是全连接层。

还有一种方式是对原有的全连接层进行修改,将1000改为10。

vgg16_false.classifier[6]=nn.Linear(4096,10)

附上所有源代码;

# -*- coding: utf-8 -*-  
# File created on 2024/8/9 
# 作者:酷尔
# 公众号:酷尔计算机import torchvision
from torch import nn
# train_data=torchvision.datasets.ImageNet('./data_imagenet',split='train',download=True,transform=torchvision.transforms.ToTensor())vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_true=torchvision.models.vgg16(pretrained=True)# print(vgg16_true)
# import os
#
# # 尝试从环境变量中获取TORCH_HOME
# torch_home = os.getenv('TORCH_HOME', os.path.expanduser('~/.torch'))
# model_cache_dir = os.path.join(torch_home, 'models')
#
# print(f"Model cache directory: {model_cache_dir}")
# 
# # 注意:这个目录可能不直接包含模型文件,因为 PyTorch 可能使用了内部的缓存机制
# # 来管理这些文件,并且它们可能以哈希名存储而不是直接以模型名存储。train_data=torchvision.datasets.CIFAR10('./dataset',train=True,download=True,transform=torchvision.transforms.ToTensor())
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
# print(vgg16_true)vgg16_false.classifier[6]=nn.Linear(4096,10)
print(vgg16_false)

相关文章:

pytorch学习(十二):对现有的模型进行修改

以VGG16为例: VGG((features): Sequential((0): Conv2d(3, 64, kernel_size(3, 3), stride(1, 1), padding(1, 1))(1): ReLU(inplaceTrue)(2): Conv2d(64, 64, kernel_size(3, 3), stride(1, 1), padding(1, 1))(3): ReLU(inplaceTrue)(4): MaxPool2d(kernel_size2…...

服务器虚拟内存是什么?虚拟内存怎么设置?

服务器虚拟内存是计算机系统内存管理的一种重要技术,它允许应用程序认为它们拥有连续且完整的内存地址空间,而实际上这些内存空间是由多个物理内存碎片和外部磁盘存储器上的空间共同组成的。当物理内存(RAM)不足时,系统…...

深度学习入门指南(1) - 从chatgpt入手

2012年,加拿大多伦多大学的Hinton教授带领他的两个学生Alex和Ilya一起用AlexNet撞开了深度学习的大门,从此人类走入了深度学习时代。 2015年,这个第二作者80后Ilya Sutskever参与创建了openai公司。现在Ilya是openai的首席科学家,…...

Python学习笔记(六)

""" 演示对序列进行切片操作 """ # 切片;从一个序列中,取出一个子序列 # 语法[起始下标:结束下标:步长] # 这三个都不写也行,视为从头到尾步长为1 # 起始下标不写,视作从头开…...

大数据安全规划总体方案(45页PPT)

方案介绍: 大数据安全规划总体方案的制定,旨在应对当前大数据环境中存在的各类安全风险,包括但不限于数据泄露、数据篡改、非法访问等。通过构建完善的安全防护体系,保障大数据在采集、存储、处理、传输、共享等全生命周期中的安…...

第20周:Pytorch文本分类入门

目录 前言 一、前期准备 1.1 环境安装导入包 1.2 加载数据 1.3 构建词典 1.4 生成数据批次和迭代器 二、准备模型 2.1 定义模型 2.2 定义示例 2.3 定义训练函数与评估函数 三、训练模型 3.1 拆分数据集并运行模型 3.2 使用测试数据集评估模型 总结 前言 &#x1…...

记一次 SpringBoot2.x 配置 Fastjson请求报 internal server 500

1.遇到的问题 报错springboot从2.1.16升级到2.5.15,之后就报500内部错误,后面调用都是正常的,就考虑转换有错。 接口返回错误: 2.解决办法 因为我用了fastjson,需要转换下,目前可能理解就是springboot-we…...

OSPF笔记

OSPF:开放式最短路径优先协议 使用范围:IGP 协议算法特点:链路状态型路由协议,SPF算法 协议是否传递网络掩码:传递网络掩码 协议封装:基于ip协议封装,协议号为89 一,ospf特点 1…...

IOC容器初始化流程

IOC容器初始化流程 一、概要1.准备上下文prepareRefresh()2. 获取beanFactory:obtainFreshBeanFactory()3. 准备beanFactory:prepareBeanFactory(beanFactory)4. 后置处理:postProcessBeanFactory()5. 调用bean工厂后置处理器:invokeBeanFactoryPostProcessors()6. 注册bea…...

第二季度云计算市场份额榜单:微软下滑,谷歌上升,AWS仍保持领先

2024 年第二季度,随着企业云支出达到 790 亿美元的新高,三大云计算巨头微软、谷歌云和 AWS的全球云市场份额发生了变化。 根据新的市场数据,以下是 2024 年第二季度全球云市场份额结果和六大世界领先者,其中包括 AWS、阿里巴巴、…...

三点确定圆心算法推导

已知a,b,c三点求过这三点的圆心坐标 a ( x 1 , y 1 ) a(x_1, y_1) a(x1​,y1​) 、 b ( x 2 , y 2 ) b(x_2, y_2) b(x2​,y2​) 、 c ( x 3 , y 3 ) c(x_3, y_3) c(x3​,y3​) 确认三点是否共线 叉积计算方式 v → ( X 1 , Y 1 ) u → ( X 2 , Y 2 ) X 1 Y 2 − X 2 Y 1 \…...

神经网络 (NN) TensorFlow Playground在线应用程序

神经网络 (NN) 历史上最重要的发现之一是神经网络 (NN) 的强大功能。 在神经网络中,称为神经元的许多数据层被添加在一起或相互堆叠以计算新的数据级别。 常用的简称: DNN 深度神经网络CNN 卷积神经网络RNN 循环神经网络 神经元 科学家一致认为&am…...

腾讯课堂 离线m3u8.sqlite转成视频

为了广大腾讯课堂用户对于购买的课程不能正常离线播放,构成知识付费损失,故出此文档。 重点:完全免费!!!完全免费!!!完全免费!!! 怎么…...

Linux多路转接

文章目录 IO模型多路转接select 和 pollepoll IO模型 在还在学习语言的阶段,C里使用cin,或者是C使用scanf的时候,总是要等着我们输入数据才执行,这种IO是阻塞IO。下面是比较正式的说法。 阻塞IO: 在内核将数据准备好之前&#xf…...

IDEA导入Maven项目的流程配置以常见问题解决

1. 前言 本文主要围绕着在IDEA中导入新Maven项目后的配置及常见问题解决来展开说说。相关的部分软件如下: IntelliJ IDEA 2021.1JDK 1.8Window 2. 导入Maven项目及配置 2.1 导入Maven项目 下面介绍了直接打开本地项目和导入git上的项目两种导入Maven方式。 1…...

【数据分析---- Pandas进阶指南:核心计算方法、缺失值处理及数据类型管理】

前言: 💞💞大家好,我是书生♡,本阶段和大家一起分享和探索数据分析,本篇文章主要讲述了:Pandas进阶指南:核心计算方法、缺失值处理及数据类型管理等等。欢迎大家一起探索讨论&#x…...

2024世界机器人大会将于8月21日至25日在京举行

2024年的世界机器人大会预定于8月21日至25日,在北京经济技术开发区的北人亦创国际会展中心隆重举办。 本届大会以“共育新质生产力 共享智能新未来”为核心主题,将汇聚来自全球超过300位的机器人行业专家、国际组织代表、杰出科学家以及企业家&#xff0…...

【Linux】lvm被删除或者lvm丢失了怎么办

模拟案例 接下来模拟lvm误删除如何恢复的案例: 模拟删除: 查看vg名: vgdisplayvgcfgrestore --list uniontechos #查看之前的操作 例如我删除的,现场没有删除就用最近的操作文件: 还原: vgcfgrestore…...

疫情防控管理系统

摘 要 由于当前疫情防控形势复杂,为做好学校疫情防控管理措施,根据上级防疫部门要求,为了学生的生命安全,要求学校加强疫情防控的管理。为了迎合时代需求,优化管理效率,各种各样的管理系统应运而生&#x…...

永久删除的Android 文件去哪了?在Android上恢复误删除的消息和照片方法?

丢失重要消息和照片可能是一种令人沮丧的经历,尤其是在您没有备份的情况下。但别担心,在本教程中,我们将指导您完成在Android设备上恢复已删除消息和照片的步骤。无论您是不小心删除了它们还是由于软件问题而消失了,这些步骤都可以…...

别再乱用分支了!Flowable四种网关(排他/并行/包容/事件)实战选型指南

Flowable四大网关实战选型:从混乱到精准的决策艺术当你在设计一个请假审批流程时,是否遇到过这样的困惑:部门经理审批后需要同时通知HR和财务,但某些特殊情况下又需要跳过财务直接归档?这种看似简单的业务需求&#xf…...

DeepSeek基准测试避坑手册:92%开发者忽略的4大陷阱——硬件配置偏差、tokenizer不一致、batch size幻觉、温度值污染

更多请点击: https://codechina.net 第一章:DeepSeek基准测试避坑手册:92%开发者忽略的4大陷阱——硬件配置偏差、tokenizer不一致、batch size幻觉、温度值污染 硬件配置偏差:GPU显存与计算精度的隐性干扰 在A100(8…...

《我看见的世界:李飞飞自传》第1-6章阅读笔记:从移民少女到AI教母的“看见“之旅

前言 当我们谈论人工智能时,我们谈论的是算法、数据、算力,是那些冰冷的代码和复杂的模型。但在《我看见的世界:李飞飞自传》中,李飞飞用她独特的视角告诉我们:AI的本质,是人类对"看见"世界的渴望…...

大佬推荐的网络安全学习路线(从基础到高级,超级详细)

大佬推荐的网络安全学习路线(从基础到高级,超级详细) 说起网络安全,你可能会担心它是一个过时的行业。有人说,网络安全快卷死了,你既要攻又要防,并且随着技术的发展,你还要不断地学…...

如何快速集成 react-native-bottom-sheet-behavior:5 分钟搞定 Android 底部弹窗

如何快速集成 react-native-bottom-sheet-behavior:5 分钟搞定 Android 底部弹窗 【免费下载链接】react-native-bottom-sheet-behavior react-native wrapper for android BottomSheetBehavior 项目地址: https://gitcode.com/gh_mirrors/re/react-native-bottom…...

Graphin高级应用:结合GISDK构建配置化图分析模块的完整指南

Graphin高级应用:结合GISDK构建配置化图分析模块的完整指南 【免费下载链接】Graphin 🌌 A React toolkit for graph visualization based on G6. 项目地址: https://gitcode.com/gh_mirrors/gr/Graphin 在当今数据驱动的时代,图可视化…...

Taotoken的审计日志功能为企业API安全与合规管理提供支持

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 Taotoken的审计日志功能为企业API安全与合规管理提供支持 当企业决定将大模型能力集成到内部业务流程中时,IT管理员和安…...

MeloTTS实战指南:解决多语言TTS部署中的核心挑战

MeloTTS实战指南:解决多语言TTS部署中的核心挑战 【免费下载链接】MeloTTS High-quality multi-lingual text-to-speech library by MyShell.ai. Support English, Spanish, French, Chinese, Japanese and Korean. 项目地址: https://gitcode.com/GitHub_Trendin…...

3大技术突破:重新定义Switch游戏安装性能极限

3大技术突破:重新定义Switch游戏安装性能极限 【免费下载链接】Awoo-Installer A No-Bullshit NSP, NSZ, XCI, and XCZ Installer for Nintendo Switch 项目地址: https://gitcode.com/gh_mirrors/aw/Awoo-Installer Awoo Installer是一款专为破解版Nintendo…...

终极虚拟显示器解决方案:ParsecVDisplay完整使用指南

终极虚拟显示器解决方案:ParsecVDisplay完整使用指南 【免费下载链接】parsec-vdd ✨ Perfect virtual display for game streaming 项目地址: https://gitcode.com/gh_mirrors/pa/parsec-vdd ParsecVDisplay是一个基于Parsec虚拟显示驱动(VDD)的独立应用程序…...