pytorch学习(十二):对现有的模型进行修改
以VGG16为例:
VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)
特征提取部分(features)
-
卷积层与ReLU激活:网络的前半部分主要由卷积层(
Conv2d)和ReLU激活函数(ReLU)交替组成。每个卷积层后都紧跟一个ReLU层,用于引入非线性。这种结构有助于网络学习复杂的特征表示。 -
卷积层配置:
- 初始阶段,使用64个3x3的卷积核,然后是ReLU激活,接着是另一个3x3卷积核和ReLU激活,之后是一个2x2的最大池化层(
MaxPool2d),用于降低特征图的尺寸并增加感受野。 - 类似地,这个过程在特征图的通道数增加到128、256和512时重复,每次增加通道数后都会跟随几个卷积层和ReLU激活,然后是一个最大池化层。
- 值得注意的是,在512通道的部分,卷积层和ReLU激活的组合被重复了三次,而没有立即进行池化,这可能是为了进一步增强特征表示。
- 初始阶段,使用64个3x3的卷积核,然后是ReLU激活,接着是另一个3x3卷积核和ReLU激活,之后是一个2x2的最大池化层(
-
最大池化层:用于在每个阶段的末尾减少特征图的尺寸,这有助于减少计算量和参数数量,同时保持重要的特征信息。
全连接层部分(classifier)
-
自适应平均池化:在特征提取部分之后,使用了一个自适应平均池化层(
AdaptiveAvgPool2d),将特征图的尺寸调整为7x7。这是为了确保无论输入图像的大小如何,全连接层都能接收到固定大小的输入。 -
全连接层:
- 第一个全连接层(
Linear)将7x7x512的特征图展平为25088个特征,并映射到4096个输出特征上。 - 接着是两个ReLU激活层、两个Dropout层(用于防止过拟合)和另外两个全连接层,最终输出1000个类别的得分(假设是用于ImageNet分类任务)。
- 第一个全连接层(
可以看到,最后一层的全连接层的输出是1000,那么当我们有例如十分类的问题时候,就需要对网络进行修改。
vgg16_true.add_module('add_linear',nn.Linear(1000,10))
运行上述代码,在末尾加一层线性层,也就是全连接层。
还有一种方式是对原有的全连接层进行修改,将1000改为10。
vgg16_false.classifier[6]=nn.Linear(4096,10)
附上所有源代码;
# -*- coding: utf-8 -*-
# File created on 2024/8/9
# 作者:酷尔
# 公众号:酷尔计算机import torchvision
from torch import nn
# train_data=torchvision.datasets.ImageNet('./data_imagenet',split='train',download=True,transform=torchvision.transforms.ToTensor())vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_true=torchvision.models.vgg16(pretrained=True)# print(vgg16_true)
# import os
#
# # 尝试从环境变量中获取TORCH_HOME
# torch_home = os.getenv('TORCH_HOME', os.path.expanduser('~/.torch'))
# model_cache_dir = os.path.join(torch_home, 'models')
#
# print(f"Model cache directory: {model_cache_dir}")
#
# # 注意:这个目录可能不直接包含模型文件,因为 PyTorch 可能使用了内部的缓存机制
# # 来管理这些文件,并且它们可能以哈希名存储而不是直接以模型名存储。train_data=torchvision.datasets.CIFAR10('./dataset',train=True,download=True,transform=torchvision.transforms.ToTensor())
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
# print(vgg16_true)vgg16_false.classifier[6]=nn.Linear(4096,10)
print(vgg16_false)
相关文章:
pytorch学习(十二):对现有的模型进行修改
以VGG16为例: VGG((features): Sequential((0): Conv2d(3, 64, kernel_size(3, 3), stride(1, 1), padding(1, 1))(1): ReLU(inplaceTrue)(2): Conv2d(64, 64, kernel_size(3, 3), stride(1, 1), padding(1, 1))(3): ReLU(inplaceTrue)(4): MaxPool2d(kernel_size2…...
服务器虚拟内存是什么?虚拟内存怎么设置?
服务器虚拟内存是计算机系统内存管理的一种重要技术,它允许应用程序认为它们拥有连续且完整的内存地址空间,而实际上这些内存空间是由多个物理内存碎片和外部磁盘存储器上的空间共同组成的。当物理内存(RAM)不足时,系统…...
深度学习入门指南(1) - 从chatgpt入手
2012年,加拿大多伦多大学的Hinton教授带领他的两个学生Alex和Ilya一起用AlexNet撞开了深度学习的大门,从此人类走入了深度学习时代。 2015年,这个第二作者80后Ilya Sutskever参与创建了openai公司。现在Ilya是openai的首席科学家,…...
Python学习笔记(六)
""" 演示对序列进行切片操作 """ # 切片;从一个序列中,取出一个子序列 # 语法[起始下标:结束下标:步长] # 这三个都不写也行,视为从头到尾步长为1 # 起始下标不写,视作从头开…...
大数据安全规划总体方案(45页PPT)
方案介绍: 大数据安全规划总体方案的制定,旨在应对当前大数据环境中存在的各类安全风险,包括但不限于数据泄露、数据篡改、非法访问等。通过构建完善的安全防护体系,保障大数据在采集、存储、处理、传输、共享等全生命周期中的安…...
第20周:Pytorch文本分类入门
目录 前言 一、前期准备 1.1 环境安装导入包 1.2 加载数据 1.3 构建词典 1.4 生成数据批次和迭代器 二、准备模型 2.1 定义模型 2.2 定义示例 2.3 定义训练函数与评估函数 三、训练模型 3.1 拆分数据集并运行模型 3.2 使用测试数据集评估模型 总结 前言 …...
记一次 SpringBoot2.x 配置 Fastjson请求报 internal server 500
1.遇到的问题 报错springboot从2.1.16升级到2.5.15,之后就报500内部错误,后面调用都是正常的,就考虑转换有错。 接口返回错误: 2.解决办法 因为我用了fastjson,需要转换下,目前可能理解就是springboot-we…...
OSPF笔记
OSPF:开放式最短路径优先协议 使用范围:IGP 协议算法特点:链路状态型路由协议,SPF算法 协议是否传递网络掩码:传递网络掩码 协议封装:基于ip协议封装,协议号为89 一,ospf特点 1…...
IOC容器初始化流程
IOC容器初始化流程 一、概要1.准备上下文prepareRefresh()2. 获取beanFactory:obtainFreshBeanFactory()3. 准备beanFactory:prepareBeanFactory(beanFactory)4. 后置处理:postProcessBeanFactory()5. 调用bean工厂后置处理器:invokeBeanFactoryPostProcessors()6. 注册bea…...
第二季度云计算市场份额榜单:微软下滑,谷歌上升,AWS仍保持领先
2024 年第二季度,随着企业云支出达到 790 亿美元的新高,三大云计算巨头微软、谷歌云和 AWS的全球云市场份额发生了变化。 根据新的市场数据,以下是 2024 年第二季度全球云市场份额结果和六大世界领先者,其中包括 AWS、阿里巴巴、…...
三点确定圆心算法推导
已知a,b,c三点求过这三点的圆心坐标 a ( x 1 , y 1 ) a(x_1, y_1) a(x1,y1) 、 b ( x 2 , y 2 ) b(x_2, y_2) b(x2,y2) 、 c ( x 3 , y 3 ) c(x_3, y_3) c(x3,y3) 确认三点是否共线 叉积计算方式 v → ( X 1 , Y 1 ) u → ( X 2 , Y 2 ) X 1 Y 2 − X 2 Y 1 \…...
神经网络 (NN) TensorFlow Playground在线应用程序
神经网络 (NN) 历史上最重要的发现之一是神经网络 (NN) 的强大功能。 在神经网络中,称为神经元的许多数据层被添加在一起或相互堆叠以计算新的数据级别。 常用的简称: DNN 深度神经网络CNN 卷积神经网络RNN 循环神经网络 神经元 科学家一致认为&am…...
腾讯课堂 离线m3u8.sqlite转成视频
为了广大腾讯课堂用户对于购买的课程不能正常离线播放,构成知识付费损失,故出此文档。 重点:完全免费!!!完全免费!!!完全免费!!! 怎么…...
Linux多路转接
文章目录 IO模型多路转接select 和 pollepoll IO模型 在还在学习语言的阶段,C里使用cin,或者是C使用scanf的时候,总是要等着我们输入数据才执行,这种IO是阻塞IO。下面是比较正式的说法。 阻塞IO: 在内核将数据准备好之前…...
IDEA导入Maven项目的流程配置以常见问题解决
1. 前言 本文主要围绕着在IDEA中导入新Maven项目后的配置及常见问题解决来展开说说。相关的部分软件如下: IntelliJ IDEA 2021.1JDK 1.8Window 2. 导入Maven项目及配置 2.1 导入Maven项目 下面介绍了直接打开本地项目和导入git上的项目两种导入Maven方式。 1…...
【数据分析---- Pandas进阶指南:核心计算方法、缺失值处理及数据类型管理】
前言: 💞💞大家好,我是书生♡,本阶段和大家一起分享和探索数据分析,本篇文章主要讲述了:Pandas进阶指南:核心计算方法、缺失值处理及数据类型管理等等。欢迎大家一起探索讨论&#x…...
2024世界机器人大会将于8月21日至25日在京举行
2024年的世界机器人大会预定于8月21日至25日,在北京经济技术开发区的北人亦创国际会展中心隆重举办。 本届大会以“共育新质生产力 共享智能新未来”为核心主题,将汇聚来自全球超过300位的机器人行业专家、国际组织代表、杰出科学家以及企业家࿰…...
【Linux】lvm被删除或者lvm丢失了怎么办
模拟案例 接下来模拟lvm误删除如何恢复的案例: 模拟删除: 查看vg名: vgdisplayvgcfgrestore --list uniontechos #查看之前的操作 例如我删除的,现场没有删除就用最近的操作文件: 还原: vgcfgrestore…...
疫情防控管理系统
摘 要 由于当前疫情防控形势复杂,为做好学校疫情防控管理措施,根据上级防疫部门要求,为了学生的生命安全,要求学校加强疫情防控的管理。为了迎合时代需求,优化管理效率,各种各样的管理系统应运而生&#x…...
永久删除的Android 文件去哪了?在Android上恢复误删除的消息和照片方法?
丢失重要消息和照片可能是一种令人沮丧的经历,尤其是在您没有备份的情况下。但别担心,在本教程中,我们将指导您完成在Android设备上恢复已删除消息和照片的步骤。无论您是不小心删除了它们还是由于软件问题而消失了,这些步骤都可以…...
装饰模式(Decorator Pattern)重构java邮件发奖系统实战
前言 现在我们有个如下的需求,设计一个邮件发奖的小系统, 需求 1.数据验证 → 2. 敏感信息加密 → 3. 日志记录 → 4. 实际发送邮件 装饰器模式(Decorator Pattern)允许向一个现有的对象添加新的功能,同时又不改变其…...
CTF show Web 红包题第六弹
提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框,很难让人不联想到SQL注入,但提示都说了不是SQL注入,所以就不往这方面想了 先查看一下网页源码,发现一段JavaScript代码,有一个关键类ctfs…...
大型活动交通拥堵治理的视觉算法应用
大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动(如演唱会、马拉松赛事、高考中考等)期间,城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例,暖城商圈曾因观众集中离场导致周边…...
安宝特方案丨船舶智造的“AR+AI+作业标准化管理解决方案”(装配)
船舶制造装配管理现状:装配工作依赖人工经验,装配工人凭借长期实践积累的操作技巧完成零部件组装。企业通常制定了装配作业指导书,但在实际执行中,工人对指导书的理解和遵循程度参差不齐。 船舶装配过程中的挑战与需求 挑战 (1…...
Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析
Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析 一、第一轮提问(基础概念问题) 1. 请解释Spring框架的核心容器是什么?它在Spring中起到什么作用? Spring框架的核心容器是IoC容器&#…...
如何应对敏捷转型中的团队阻力
应对敏捷转型中的团队阻力需要明确沟通敏捷转型目的、提升团队参与感、提供充分的培训与支持、逐步推进敏捷实践、建立清晰的奖励和反馈机制。其中,明确沟通敏捷转型目的尤为关键,团队成员只有清晰理解转型背后的原因和利益,才能降低对变化的…...
华为OD机试-最短木板长度-二分法(A卷,100分)
此题是一个最大化最小值的典型例题, 因为搜索范围是有界的,上界最大木板长度补充的全部木料长度,下界最小木板长度; 即left0,right10^6; 我们可以设置一个候选值x(mid),将木板的长度全部都补充到x,如果成功…...
Monorepo架构: Nx Cloud 扩展能力与缓存加速
借助 Nx Cloud 实现项目协同与加速构建 1 ) 缓存工作原理分析 在了解了本地缓存和远程缓存之后,我们来探究缓存是如何工作的。以计算文件的哈希串为例,若后续运行任务时文件哈希串未变,系统会直接使用对应的输出和制品文件。 2 …...
PH热榜 | 2025-06-08
1. Thiings 标语:一套超过1900个免费AI生成的3D图标集合 介绍:Thiings是一个不断扩展的免费AI生成3D图标库,目前已有超过1900个图标。你可以按照主题浏览,生成自己的图标,或者下载整个图标集。所有图标都可以在个人或…...
用js实现常见排序算法
以下是几种常见排序算法的 JS实现,包括选择排序、冒泡排序、插入排序、快速排序和归并排序,以及每种算法的特点和复杂度分析 1. 选择排序(Selection Sort) 核心思想:每次从未排序部分选择最小元素,与未排…...
