当前位置: 首页 > news >正文

pytorch学习(十二):对现有的模型进行修改

以VGG16为例:

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

特征提取部分(features

  • 卷积层与ReLU激活:网络的前半部分主要由卷积层(Conv2d)和ReLU激活函数(ReLU)交替组成。每个卷积层后都紧跟一个ReLU层,用于引入非线性。这种结构有助于网络学习复杂的特征表示。

  • 卷积层配置

    • 初始阶段,使用64个3x3的卷积核,然后是ReLU激活,接着是另一个3x3卷积核和ReLU激活,之后是一个2x2的最大池化层(MaxPool2d),用于降低特征图的尺寸并增加感受野。
    • 类似地,这个过程在特征图的通道数增加到128、256和512时重复,每次增加通道数后都会跟随几个卷积层和ReLU激活,然后是一个最大池化层。
    • 值得注意的是,在512通道的部分,卷积层和ReLU激活的组合被重复了三次,而没有立即进行池化,这可能是为了进一步增强特征表示。
  • 最大池化层:用于在每个阶段的末尾减少特征图的尺寸,这有助于减少计算量和参数数量,同时保持重要的特征信息。

全连接层部分(classifier

  • 自适应平均池化:在特征提取部分之后,使用了一个自适应平均池化层(AdaptiveAvgPool2d),将特征图的尺寸调整为7x7。这是为了确保无论输入图像的大小如何,全连接层都能接收到固定大小的输入。

  • 全连接层

    • 第一个全连接层(Linear)将7x7x512的特征图展平为25088个特征,并映射到4096个输出特征上。
    • 接着是两个ReLU激活层、两个Dropout层(用于防止过拟合)和另外两个全连接层,最终输出1000个类别的得分(假设是用于ImageNet分类任务)。

可以看到,最后一层的全连接层的输出是1000,那么当我们有例如十分类的问题时候,就需要对网络进行修改。

 

vgg16_true.add_module('add_linear',nn.Linear(1000,10))

运行上述代码,在末尾加一层线性层,也就是全连接层。

还有一种方式是对原有的全连接层进行修改,将1000改为10。

vgg16_false.classifier[6]=nn.Linear(4096,10)

附上所有源代码;

# -*- coding: utf-8 -*-  
# File created on 2024/8/9 
# 作者:酷尔
# 公众号:酷尔计算机import torchvision
from torch import nn
# train_data=torchvision.datasets.ImageNet('./data_imagenet',split='train',download=True,transform=torchvision.transforms.ToTensor())vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_true=torchvision.models.vgg16(pretrained=True)# print(vgg16_true)
# import os
#
# # 尝试从环境变量中获取TORCH_HOME
# torch_home = os.getenv('TORCH_HOME', os.path.expanduser('~/.torch'))
# model_cache_dir = os.path.join(torch_home, 'models')
#
# print(f"Model cache directory: {model_cache_dir}")
# 
# # 注意:这个目录可能不直接包含模型文件,因为 PyTorch 可能使用了内部的缓存机制
# # 来管理这些文件,并且它们可能以哈希名存储而不是直接以模型名存储。train_data=torchvision.datasets.CIFAR10('./dataset',train=True,download=True,transform=torchvision.transforms.ToTensor())
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
# print(vgg16_true)vgg16_false.classifier[6]=nn.Linear(4096,10)
print(vgg16_false)

相关文章:

pytorch学习(十二):对现有的模型进行修改

以VGG16为例: VGG((features): Sequential((0): Conv2d(3, 64, kernel_size(3, 3), stride(1, 1), padding(1, 1))(1): ReLU(inplaceTrue)(2): Conv2d(64, 64, kernel_size(3, 3), stride(1, 1), padding(1, 1))(3): ReLU(inplaceTrue)(4): MaxPool2d(kernel_size2…...

服务器虚拟内存是什么?虚拟内存怎么设置?

服务器虚拟内存是计算机系统内存管理的一种重要技术,它允许应用程序认为它们拥有连续且完整的内存地址空间,而实际上这些内存空间是由多个物理内存碎片和外部磁盘存储器上的空间共同组成的。当物理内存(RAM)不足时,系统…...

深度学习入门指南(1) - 从chatgpt入手

2012年,加拿大多伦多大学的Hinton教授带领他的两个学生Alex和Ilya一起用AlexNet撞开了深度学习的大门,从此人类走入了深度学习时代。 2015年,这个第二作者80后Ilya Sutskever参与创建了openai公司。现在Ilya是openai的首席科学家,…...

Python学习笔记(六)

""" 演示对序列进行切片操作 """ # 切片;从一个序列中,取出一个子序列 # 语法[起始下标:结束下标:步长] # 这三个都不写也行,视为从头到尾步长为1 # 起始下标不写,视作从头开…...

大数据安全规划总体方案(45页PPT)

方案介绍: 大数据安全规划总体方案的制定,旨在应对当前大数据环境中存在的各类安全风险,包括但不限于数据泄露、数据篡改、非法访问等。通过构建完善的安全防护体系,保障大数据在采集、存储、处理、传输、共享等全生命周期中的安…...

第20周:Pytorch文本分类入门

目录 前言 一、前期准备 1.1 环境安装导入包 1.2 加载数据 1.3 构建词典 1.4 生成数据批次和迭代器 二、准备模型 2.1 定义模型 2.2 定义示例 2.3 定义训练函数与评估函数 三、训练模型 3.1 拆分数据集并运行模型 3.2 使用测试数据集评估模型 总结 前言 &#x1…...

记一次 SpringBoot2.x 配置 Fastjson请求报 internal server 500

1.遇到的问题 报错springboot从2.1.16升级到2.5.15,之后就报500内部错误,后面调用都是正常的,就考虑转换有错。 接口返回错误: 2.解决办法 因为我用了fastjson,需要转换下,目前可能理解就是springboot-we…...

OSPF笔记

OSPF:开放式最短路径优先协议 使用范围:IGP 协议算法特点:链路状态型路由协议,SPF算法 协议是否传递网络掩码:传递网络掩码 协议封装:基于ip协议封装,协议号为89 一,ospf特点 1…...

IOC容器初始化流程

IOC容器初始化流程 一、概要1.准备上下文prepareRefresh()2. 获取beanFactory:obtainFreshBeanFactory()3. 准备beanFactory:prepareBeanFactory(beanFactory)4. 后置处理:postProcessBeanFactory()5. 调用bean工厂后置处理器:invokeBeanFactoryPostProcessors()6. 注册bea…...

第二季度云计算市场份额榜单:微软下滑,谷歌上升,AWS仍保持领先

2024 年第二季度,随着企业云支出达到 790 亿美元的新高,三大云计算巨头微软、谷歌云和 AWS的全球云市场份额发生了变化。 根据新的市场数据,以下是 2024 年第二季度全球云市场份额结果和六大世界领先者,其中包括 AWS、阿里巴巴、…...

三点确定圆心算法推导

已知a,b,c三点求过这三点的圆心坐标 a ( x 1 , y 1 ) a(x_1, y_1) a(x1​,y1​) 、 b ( x 2 , y 2 ) b(x_2, y_2) b(x2​,y2​) 、 c ( x 3 , y 3 ) c(x_3, y_3) c(x3​,y3​) 确认三点是否共线 叉积计算方式 v → ( X 1 , Y 1 ) u → ( X 2 , Y 2 ) X 1 Y 2 − X 2 Y 1 \…...

神经网络 (NN) TensorFlow Playground在线应用程序

神经网络 (NN) 历史上最重要的发现之一是神经网络 (NN) 的强大功能。 在神经网络中,称为神经元的许多数据层被添加在一起或相互堆叠以计算新的数据级别。 常用的简称: DNN 深度神经网络CNN 卷积神经网络RNN 循环神经网络 神经元 科学家一致认为&am…...

腾讯课堂 离线m3u8.sqlite转成视频

为了广大腾讯课堂用户对于购买的课程不能正常离线播放,构成知识付费损失,故出此文档。 重点:完全免费!!!完全免费!!!完全免费!!! 怎么…...

Linux多路转接

文章目录 IO模型多路转接select 和 pollepoll IO模型 在还在学习语言的阶段,C里使用cin,或者是C使用scanf的时候,总是要等着我们输入数据才执行,这种IO是阻塞IO。下面是比较正式的说法。 阻塞IO: 在内核将数据准备好之前&#xf…...

IDEA导入Maven项目的流程配置以常见问题解决

1. 前言 本文主要围绕着在IDEA中导入新Maven项目后的配置及常见问题解决来展开说说。相关的部分软件如下: IntelliJ IDEA 2021.1JDK 1.8Window 2. 导入Maven项目及配置 2.1 导入Maven项目 下面介绍了直接打开本地项目和导入git上的项目两种导入Maven方式。 1…...

【数据分析---- Pandas进阶指南:核心计算方法、缺失值处理及数据类型管理】

前言: 💞💞大家好,我是书生♡,本阶段和大家一起分享和探索数据分析,本篇文章主要讲述了:Pandas进阶指南:核心计算方法、缺失值处理及数据类型管理等等。欢迎大家一起探索讨论&#x…...

2024世界机器人大会将于8月21日至25日在京举行

2024年的世界机器人大会预定于8月21日至25日,在北京经济技术开发区的北人亦创国际会展中心隆重举办。 本届大会以“共育新质生产力 共享智能新未来”为核心主题,将汇聚来自全球超过300位的机器人行业专家、国际组织代表、杰出科学家以及企业家&#xff0…...

【Linux】lvm被删除或者lvm丢失了怎么办

模拟案例 接下来模拟lvm误删除如何恢复的案例: 模拟删除: 查看vg名: vgdisplayvgcfgrestore --list uniontechos #查看之前的操作 例如我删除的,现场没有删除就用最近的操作文件: 还原: vgcfgrestore…...

疫情防控管理系统

摘 要 由于当前疫情防控形势复杂,为做好学校疫情防控管理措施,根据上级防疫部门要求,为了学生的生命安全,要求学校加强疫情防控的管理。为了迎合时代需求,优化管理效率,各种各样的管理系统应运而生&#x…...

永久删除的Android 文件去哪了?在Android上恢复误删除的消息和照片方法?

丢失重要消息和照片可能是一种令人沮丧的经历,尤其是在您没有备份的情况下。但别担心,在本教程中,我们将指导您完成在Android设备上恢复已删除消息和照片的步骤。无论您是不小心删除了它们还是由于软件问题而消失了,这些步骤都可以…...

Qwen3-14B私有部署镜像:Android Studio移动端AI应用原型开发

Qwen3-14B私有部署镜像:Android Studio移动端AI应用原型开发 1. 移动端AI应用开发新选择 最近在开发一个需要集成大语言模型的Android应用时,发现很多开发者都在寻找既强大又容易集成的AI解决方案。Qwen3-14B作为一款性能优异的中文大模型,…...

Python玩转微信自动化:除了监控聊天,uiautomation还能帮你自动保存文件、整理聊天记录

Python实现微信自动化管理:从文件归档到聊天记录整理 微信已经成为现代办公不可或缺的沟通工具,但随之而来的是海量文件管理和聊天记录整理的烦恼。每天手动保存图片、文档,再按日期分类,不仅耗时耗力,还容易遗漏重要…...

告别‘Setup is running...’卡死!保姆级PowerBuilder 9.0安装避坑指南(附安全模式备用方案)

PowerBuilder 9.0安装全攻略:从卡死困境到流畅部署的终极解决方案 如果你曾经在安装PowerBuilder 9.0时遭遇过"Setup is running..."的无限卡死,那么这篇文章就是为你量身定制的救星。作为一款经典的企业级开发工具,PowerBuilder至…...

Phi-4-mini-reasoning加速深度学习:卷积神经网络(CNN)模型设计与调优实战

Phi-4-mini-reasoning加速深度学习:卷积神经网络(CNN)模型设计与调优实战 1. 引言:当AI开始设计AI 在图像分类任务中,我们常常陷入这样的困境:面对海量的网络结构选择和超参数组合,即使是有经…...

CCS12.2搭配C2000ware 4.03导入工程报错?手把手教你修复头文件路径变量(MATLAB 2023b适用)

CCS12.2与C2000ware 4.03工程导入报错全解析:从路径变量修复到MATLAB 2023b联调实战 当你满怀期待地将MATLAB 2023b生成的代码导入CCS12.2,准备与C2000ware 4.03来场完美邂逅时,编译器却毫不留情地抛出一连串头文件找不到的错误——这种从云端…...

Pixel Script Temple多场景落地:政务宣传短视频、乡村振兴纪录片脚本生成

Pixel Script Temple多场景落地:政务宣传短视频、乡村振兴纪录片脚本生成 1. 专业剧本创作工具介绍 Pixel Script Temple(像素剧本圣殿)是一款基于Qwen2.5-14B-Instruct大模型深度优化的专业剧本创作工具。它将先进的AI推理能力与独特的8-B…...

RexUniNLU开源模型实战:400MB模型在A10/A100/T4不同GPU上的适配

RexUniNLU开源模型实战:400MB模型在A10/A100/T4不同GPU上的适配 1. 引言 你是否遇到过这样的困扰:想要使用强大的自然语言理解模型,但动辄几十GB的大模型让部署变得困难重重?或者你的GPU显存有限,无法运行那些"…...

Obsidian-skills日志系统:如何记录和分析AI技能使用情况

Obsidian-skills日志系统:如何记录和分析AI技能使用情况 【免费下载链接】obsidian-skills Agent skills for Obsidian. Teach your agent to use Markdown, Bases, JSON Canvas, and use the CLI. 项目地址: https://gitcode.com/GitHub_Trending/ob/obsidian-sk…...

基于Simulink的Smith预估器PID整定与延迟系统控制实验

1. 从零开始理解Smith预估控制 第一次接触Smith预估器时,我也被这个"时间旅行"般的概念惊艳到了。想象一下,你正在用热水器洗澡,每次调节水温都要等10秒才能感受到变化——这就是典型的纯延迟系统。Smith预估器的精妙之处在于&…...

OpenClaw开发助手:Qwen3.5-9B支持的代码调试与日志分析

OpenClaw开发助手:Qwen3.5-9B支持的代码调试与日志分析 1. 为什么开发者需要AI辅助调试? 深夜两点,我盯着终端里不断刷新的错误日志,第17次尝试修复那个诡异的空指针异常。咖啡杯早已见底,而问题依然像迷宫般无解——…...