pytorch学习(十二):对现有的模型进行修改
以VGG16为例:
VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)
特征提取部分(features)
-
卷积层与ReLU激活:网络的前半部分主要由卷积层(
Conv2d)和ReLU激活函数(ReLU)交替组成。每个卷积层后都紧跟一个ReLU层,用于引入非线性。这种结构有助于网络学习复杂的特征表示。 -
卷积层配置:
- 初始阶段,使用64个3x3的卷积核,然后是ReLU激活,接着是另一个3x3卷积核和ReLU激活,之后是一个2x2的最大池化层(
MaxPool2d),用于降低特征图的尺寸并增加感受野。 - 类似地,这个过程在特征图的通道数增加到128、256和512时重复,每次增加通道数后都会跟随几个卷积层和ReLU激活,然后是一个最大池化层。
- 值得注意的是,在512通道的部分,卷积层和ReLU激活的组合被重复了三次,而没有立即进行池化,这可能是为了进一步增强特征表示。
- 初始阶段,使用64个3x3的卷积核,然后是ReLU激活,接着是另一个3x3卷积核和ReLU激活,之后是一个2x2的最大池化层(
-
最大池化层:用于在每个阶段的末尾减少特征图的尺寸,这有助于减少计算量和参数数量,同时保持重要的特征信息。
全连接层部分(classifier)
-
自适应平均池化:在特征提取部分之后,使用了一个自适应平均池化层(
AdaptiveAvgPool2d),将特征图的尺寸调整为7x7。这是为了确保无论输入图像的大小如何,全连接层都能接收到固定大小的输入。 -
全连接层:
- 第一个全连接层(
Linear)将7x7x512的特征图展平为25088个特征,并映射到4096个输出特征上。 - 接着是两个ReLU激活层、两个Dropout层(用于防止过拟合)和另外两个全连接层,最终输出1000个类别的得分(假设是用于ImageNet分类任务)。
- 第一个全连接层(
可以看到,最后一层的全连接层的输出是1000,那么当我们有例如十分类的问题时候,就需要对网络进行修改。
vgg16_true.add_module('add_linear',nn.Linear(1000,10))
运行上述代码,在末尾加一层线性层,也就是全连接层。
还有一种方式是对原有的全连接层进行修改,将1000改为10。
vgg16_false.classifier[6]=nn.Linear(4096,10)
附上所有源代码;
# -*- coding: utf-8 -*-
# File created on 2024/8/9
# 作者:酷尔
# 公众号:酷尔计算机import torchvision
from torch import nn
# train_data=torchvision.datasets.ImageNet('./data_imagenet',split='train',download=True,transform=torchvision.transforms.ToTensor())vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_true=torchvision.models.vgg16(pretrained=True)# print(vgg16_true)
# import os
#
# # 尝试从环境变量中获取TORCH_HOME
# torch_home = os.getenv('TORCH_HOME', os.path.expanduser('~/.torch'))
# model_cache_dir = os.path.join(torch_home, 'models')
#
# print(f"Model cache directory: {model_cache_dir}")
#
# # 注意:这个目录可能不直接包含模型文件,因为 PyTorch 可能使用了内部的缓存机制
# # 来管理这些文件,并且它们可能以哈希名存储而不是直接以模型名存储。train_data=torchvision.datasets.CIFAR10('./dataset',train=True,download=True,transform=torchvision.transforms.ToTensor())
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
# print(vgg16_true)vgg16_false.classifier[6]=nn.Linear(4096,10)
print(vgg16_false)
相关文章:
pytorch学习(十二):对现有的模型进行修改
以VGG16为例: VGG((features): Sequential((0): Conv2d(3, 64, kernel_size(3, 3), stride(1, 1), padding(1, 1))(1): ReLU(inplaceTrue)(2): Conv2d(64, 64, kernel_size(3, 3), stride(1, 1), padding(1, 1))(3): ReLU(inplaceTrue)(4): MaxPool2d(kernel_size2…...
服务器虚拟内存是什么?虚拟内存怎么设置?
服务器虚拟内存是计算机系统内存管理的一种重要技术,它允许应用程序认为它们拥有连续且完整的内存地址空间,而实际上这些内存空间是由多个物理内存碎片和外部磁盘存储器上的空间共同组成的。当物理内存(RAM)不足时,系统…...
深度学习入门指南(1) - 从chatgpt入手
2012年,加拿大多伦多大学的Hinton教授带领他的两个学生Alex和Ilya一起用AlexNet撞开了深度学习的大门,从此人类走入了深度学习时代。 2015年,这个第二作者80后Ilya Sutskever参与创建了openai公司。现在Ilya是openai的首席科学家,…...
Python学习笔记(六)
""" 演示对序列进行切片操作 """ # 切片;从一个序列中,取出一个子序列 # 语法[起始下标:结束下标:步长] # 这三个都不写也行,视为从头到尾步长为1 # 起始下标不写,视作从头开…...
大数据安全规划总体方案(45页PPT)
方案介绍: 大数据安全规划总体方案的制定,旨在应对当前大数据环境中存在的各类安全风险,包括但不限于数据泄露、数据篡改、非法访问等。通过构建完善的安全防护体系,保障大数据在采集、存储、处理、传输、共享等全生命周期中的安…...
第20周:Pytorch文本分类入门
目录 前言 一、前期准备 1.1 环境安装导入包 1.2 加载数据 1.3 构建词典 1.4 生成数据批次和迭代器 二、准备模型 2.1 定义模型 2.2 定义示例 2.3 定义训练函数与评估函数 三、训练模型 3.1 拆分数据集并运行模型 3.2 使用测试数据集评估模型 总结 前言 …...
记一次 SpringBoot2.x 配置 Fastjson请求报 internal server 500
1.遇到的问题 报错springboot从2.1.16升级到2.5.15,之后就报500内部错误,后面调用都是正常的,就考虑转换有错。 接口返回错误: 2.解决办法 因为我用了fastjson,需要转换下,目前可能理解就是springboot-we…...
OSPF笔记
OSPF:开放式最短路径优先协议 使用范围:IGP 协议算法特点:链路状态型路由协议,SPF算法 协议是否传递网络掩码:传递网络掩码 协议封装:基于ip协议封装,协议号为89 一,ospf特点 1…...
IOC容器初始化流程
IOC容器初始化流程 一、概要1.准备上下文prepareRefresh()2. 获取beanFactory:obtainFreshBeanFactory()3. 准备beanFactory:prepareBeanFactory(beanFactory)4. 后置处理:postProcessBeanFactory()5. 调用bean工厂后置处理器:invokeBeanFactoryPostProcessors()6. 注册bea…...
第二季度云计算市场份额榜单:微软下滑,谷歌上升,AWS仍保持领先
2024 年第二季度,随着企业云支出达到 790 亿美元的新高,三大云计算巨头微软、谷歌云和 AWS的全球云市场份额发生了变化。 根据新的市场数据,以下是 2024 年第二季度全球云市场份额结果和六大世界领先者,其中包括 AWS、阿里巴巴、…...
三点确定圆心算法推导
已知a,b,c三点求过这三点的圆心坐标 a ( x 1 , y 1 ) a(x_1, y_1) a(x1,y1) 、 b ( x 2 , y 2 ) b(x_2, y_2) b(x2,y2) 、 c ( x 3 , y 3 ) c(x_3, y_3) c(x3,y3) 确认三点是否共线 叉积计算方式 v → ( X 1 , Y 1 ) u → ( X 2 , Y 2 ) X 1 Y 2 − X 2 Y 1 \…...
神经网络 (NN) TensorFlow Playground在线应用程序
神经网络 (NN) 历史上最重要的发现之一是神经网络 (NN) 的强大功能。 在神经网络中,称为神经元的许多数据层被添加在一起或相互堆叠以计算新的数据级别。 常用的简称: DNN 深度神经网络CNN 卷积神经网络RNN 循环神经网络 神经元 科学家一致认为&am…...
腾讯课堂 离线m3u8.sqlite转成视频
为了广大腾讯课堂用户对于购买的课程不能正常离线播放,构成知识付费损失,故出此文档。 重点:完全免费!!!完全免费!!!完全免费!!! 怎么…...
Linux多路转接
文章目录 IO模型多路转接select 和 pollepoll IO模型 在还在学习语言的阶段,C里使用cin,或者是C使用scanf的时候,总是要等着我们输入数据才执行,这种IO是阻塞IO。下面是比较正式的说法。 阻塞IO: 在内核将数据准备好之前…...
IDEA导入Maven项目的流程配置以常见问题解决
1. 前言 本文主要围绕着在IDEA中导入新Maven项目后的配置及常见问题解决来展开说说。相关的部分软件如下: IntelliJ IDEA 2021.1JDK 1.8Window 2. 导入Maven项目及配置 2.1 导入Maven项目 下面介绍了直接打开本地项目和导入git上的项目两种导入Maven方式。 1…...
【数据分析---- Pandas进阶指南:核心计算方法、缺失值处理及数据类型管理】
前言: 💞💞大家好,我是书生♡,本阶段和大家一起分享和探索数据分析,本篇文章主要讲述了:Pandas进阶指南:核心计算方法、缺失值处理及数据类型管理等等。欢迎大家一起探索讨论&#x…...
2024世界机器人大会将于8月21日至25日在京举行
2024年的世界机器人大会预定于8月21日至25日,在北京经济技术开发区的北人亦创国际会展中心隆重举办。 本届大会以“共育新质生产力 共享智能新未来”为核心主题,将汇聚来自全球超过300位的机器人行业专家、国际组织代表、杰出科学家以及企业家࿰…...
【Linux】lvm被删除或者lvm丢失了怎么办
模拟案例 接下来模拟lvm误删除如何恢复的案例: 模拟删除: 查看vg名: vgdisplayvgcfgrestore --list uniontechos #查看之前的操作 例如我删除的,现场没有删除就用最近的操作文件: 还原: vgcfgrestore…...
疫情防控管理系统
摘 要 由于当前疫情防控形势复杂,为做好学校疫情防控管理措施,根据上级防疫部门要求,为了学生的生命安全,要求学校加强疫情防控的管理。为了迎合时代需求,优化管理效率,各种各样的管理系统应运而生&#x…...
永久删除的Android 文件去哪了?在Android上恢复误删除的消息和照片方法?
丢失重要消息和照片可能是一种令人沮丧的经历,尤其是在您没有备份的情况下。但别担心,在本教程中,我们将指导您完成在Android设备上恢复已删除消息和照片的步骤。无论您是不小心删除了它们还是由于软件问题而消失了,这些步骤都可以…...
SI4463射频项目实战:我是如何用WDS3配置工具搞定868MHz双向通信的
SI4463射频项目实战:从WDS3配置到868MHz双向通信的完整实现 在物联网设备开发中,稳定可靠的无线通信是实现设备互联的关键。SI4463作为Silicon Labs推出的一款高性能Sub-GHz射频芯片,凭借其低功耗、高灵敏度和灵活的配置选项,成为…...
5分钟搞定USR-K5模块配置:串口转以太网通讯的保姆级教程
5分钟搞定USR-K5模块配置:串口转以太网通讯的保姆级教程 当你需要在嵌入式系统中快速实现串口设备与以太网的互联时,USR-K5模块是个不错的选择。这款小巧的串口转以太网模块,能够帮助开发者省去复杂的网络协议栈开发工作,特别适合…...
seo排名工具可以提升网站排名吗
SEO排名工具能否提升网站排名?深入解析与实用建议 在当前互联网时代,网站的排名直接影响着其流量和转化率。许多网站主和数字营销人员常常使用SEO排名工具来提升网站的搜索引擎排名。SEO排名工具能否真正提升网站排名呢?本文将从问题分析、原…...
ComfyUI-VideoHelperSuite:构建高性能视频处理管道的异步架构设计
ComfyUI-VideoHelperSuite:构建高性能视频处理管道的异步架构设计 【免费下载链接】ComfyUI-VideoHelperSuite Nodes related to video workflows 项目地址: https://gitcode.com/gh_mirrors/co/ComfyUI-VideoHelperSuite ComfyUI-VideoHelperSuite是一个专门…...
交换机接口全解析:从RJ-45到光纤,一文掌握所有连接技巧
1. 交换机接口基础:认识常见的物理接口类型 第一次拆开交换机包装时,面对密密麻麻的接口面板,新手常会感到无从下手。其实这些接口按照传输介质可分为两大阵营:电口和光口。电口就是我们熟悉的RJ-45接口,而光口则包含…...
春联生成模型安装包制作:一键部署exe工具开发
春联生成模型安装包制作:一键部署exe工具开发 1. 引言 每年春节前,很多朋友都想自己动手写春联,但要么字写得不够好看,要么想不出有新意的词句。现在有了AI春联生成模型,这个问题就简单多了。不过,对于不…...
避坑指南:ZYNQ lwIP Socket TCP服务器开发中,DHCP超时、内存泄漏和任务卡死的调试经验
ZYNQ lwIP TCP服务器开发实战:从实验室到工业环境的稳定性优化 在嵌入式网络开发中,ZYNQ平台结合lwIP协议栈的TCP服务器实现看似简单,但当代码从实验室走向真实工业环境时,开发者往往会遭遇一系列"幽灵问题"——DHCP获取…...
AI绘画作品集:Anything V5图像生成服务实际效果与案例分享
AI绘画作品集:Anything V5图像生成服务实际效果与案例分享 1. 引言:当AI绘画遇见Anything V5 想象一下,你有一个创意在脑海中盘旋——也许是一个穿着宇航服在咖啡馆里喝咖啡的熊猫,或者是一座漂浮在云端的蒸汽朋克城市。在过去&…...
5个PathPicker高级技巧:掌握$F令牌与自定义命令的终极指南
5个PathPicker高级技巧:掌握$F令牌与自定义命令的终极指南 【免费下载链接】PathPicker PathPicker accepts a wide range of input -- output from git commands, grep results, searches -- pretty much anything. After parsing the input, PathPicker presents …...
SEO_ 揭秘影响搜索引擎排名的核心因素与算法
SEO核心因素解析:揭秘影响搜索引擎排名的算法 在互联网时代,搜索引擎优化(SEO)已成为每一个网站运营者的重要关注点。SEO不仅关系到网站的流量,更直接影响到网站的知名度和商业价值。究竟有哪些核心因素和算法影响着搜…...
