当前位置: 首页 > news >正文

5_现有网络模型的使用

 教程:现有网络模型的使用及修改_哔哩哔哩_bilibili

官方网址:https://pytorch.org/vision/stable/models.html#classification

 初识网络模型

pytorch为我们提供了许多已经构造好的网络模型,我们只要将它们加载进来,就可以直接使用。以torchvision为例,关于神经网络处理图像的模型就分为好几个大类:如图像分类、目标检测、语义分割等等。如图所示:

 视频中的讲解以VGG模型为例,来向我们展示了网络模型的使用。

因为这个教学视频也已经是两三年前了的,现在和之前略微有所区别。在这里,简单做一个说明:比如说模型加载过程中参数的改变:

如今的模型中不再有pretrained参数,也就是如果大家需要下载模型的权重文件,需要自己手动下载。务必注意,写了会报错哦。 

权重文件的下载

 视频中有讲到模型的下载也是不大不小的,如果不进行设置,一般会默认下载在c盘,想要进行设置的话,可以在网上搜索有关代码:Pytorch预训练模型下载并加载(以VGG为例)自定义路径_怎么更改vgg下载路径-CSDN博客

但以上这位同学的方法我使用时出错,提示我没有这个属性:

model_zoo._download_url_to_file(url, os.path.join(dst_path, filename), hash_prefix, True)
AttributeError: module 'torch.utils.model_zoo' has no attribute '_download_url_to_file'

所以我略加修改,以下是我的处理下载过程,同样出错的同学可以看看:

from urllib.parse import urlparse
import torch
# import re
import os
def download_model(url, dst_path):parts = urlparse(url)filename = os.path.basename(parts.path)# HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')# hash_prefix = HASH_REGEX.search(filename).group(1)torch.hub.download_url_to_file(url, os.path.join(dst_path, filename))return filenamepath = "D:\\vscodeProjects\\models"
if not (os.path.exists(path)):os.makedirs(path)
url='https://download.pytorch.org/models/vgg16-397923af.pth'
download_model(url, path)

 只是这个下载的速度着实太慢,我先放弃了:

 关于这个权重文件的下载我犯了一点小迷糊。我有点搞不懂为什么费劲巴拉下载这么大个东西然后视频中又仅仅使用vgg16=torchvision.models.vgg16()这一句话就完事了。

于是我搜索了一下:

  • 在 PyTorch 中,许多流行的深度学习模型(如 VGG、ResNet、AlexNet 等)都有预先训练好的权重文件可供下载。这些权重文件包含了模型在大规模数据集(如 ImageNet)上训练的参数,可以帮助加快模型的收敛速度,提升模型的表现。下载预训练模型通常是为了避免从头开始训练模型,节省时间和计算资源。
  • torchvision.models 是 PyTorch 提供的一个模块,用于加载常见的计算机视觉模型,例如 VGG、ResNet、AlexNet 等。这些模型可以通过简单的调用来导入,并且可以选择加载预训练的权重。

 简而言之,权重文件可以简化我们模型的训练过程,我们可以通过使用权重文件来直接利用前辈的训练结果,稍作修改就可以变成我们自己的东西。

如果只是用vgg16=torchvision.models.vgg16()这么一句话来加载网络模型,得到的模型只有结构而没有经过训练的过程,因此它的权重是初始的。

网络模型的修改

因为官网中提到的VGG模型的官配数据集ImageNet实在是太大了(100+个G),笔记本实在带不了,所以还是使用我们之前已经用了很多次的数据集CIFAR10来搞,正好可以讲解一下怎样修改网络模型。

原官配数据集非常之大(对我一个初学者来说,是暂时见过最大的数据集了),最终一共分为1000个类。因此这个VGG模型最终输出为1000,为了适配于我们这个CIFAR10数据集(输出只有10类),我们为加载下来的VGG模型添加一个线性层,将原本的1000个类最终输出为10类。

from torch import nn
import torchvision
vgg16=torchvision.models.vgg16()
train_data=torchvision.datasets.CIFAR10("../dataset",train=True,transform=torchvision.transforms.ToTensor())
vgg16.add_module('add_linear',nn.Linear(1000,10))

print(vgg16)可以看到,最下面就是我们新添加的层:

 如果我们想添加在classifier这个模型中,我们也可以这样写:

vgg16.classifier.add_module('add_linear',nn.Linear(1000,10))

同样打印一下看效果:

 当然如果我们不想添加新的一层,我们也可以通过另外的一种方式来将输出从1000改为10:

如上图所示,已知最后一层是线性层,输入4096,输出1000,那么我们现在直接将最后一个线性层修改,输出改成10:

vgg16.classifier[6]=nn.Linear(in_features=4096,out_features=10,bias=True)

看结果:

模型的保存和加载

如果我们对网络模型进行了修改或者训练,如何将我们自己的模型保存下来呢?一共有以下两种方式:

vgg16=torchvision.models.vgg16()
vgg16.classifier[6]=nn.Linear(in_features=4096,out_features=10,bias=True)
#保存方式一:保存权重文件和模型结构
torch.save(vgg16,"vgg16_method1.pth")
#保存方式二(官方推荐),实际上保存的是权重文件,以字典方式存储:
torch.save(vgg16.state_dict(),"vgg16_method2.pth")

而如果我们想要取出我们已经保存的模型,就可以:

#方式一加载保存的模型
vgg16_method1=torch.load("vgg16_method1.pth")
#方式二加载保存的权重文件
vgg16_method2=torch.load("vgg16_method2.pth")
vgg16=torchvision.models.vgg16()
vgg16.load_state_dict(vgg16_method2)

相关文章:

5_现有网络模型的使用

教程:现有网络模型的使用及修改_哔哩哔哩_bilibili 官方网址:https://pytorch.org/vision/stable/models.html#classification 初识网络模型 pytorch为我们提供了许多已经构造好的网络模型,我们只要将它们加载进来,就可以直接使…...

软件安全测试报告内容和作用简析,软件测试服务供应商推荐

在数字化时代,软件安全问题愈发凸显,安全测试显得尤为重要。软件安全测试报告是对软件系统在安全性方面进行评估和分析后的书面文件。该报告通常包含测试过程、测试发现、漏洞描述、风险评估及改进建议等重要信息。报告的目的是为了帮助开发团队及时发现…...

算法板子:树形DP、树的DFS——树的重心

思想&#xff1a; 代码&#xff1a; #include <iostream> #include <cstring> using namespace std;const int N 1e5 10;// vis标记当前节点是否被访问过; vis[1]true代表编号为1的节点被访问过 bool vis[N]; // h数组为邻接表; h数组上的每个坑位都串了一个单链…...

在C语言中,联合体或共用体(union )是一种特殊的数据类型,允许在相同的内存位置存储不同的数据类型。

在C语言中&#xff0c;union 是一种特殊的数据类型&#xff0c;允许在相同的内存位置存储不同的数据类型。这意味着 union 中的所有成员共享同一块内存空间&#xff0c;因此它们之间会相互覆盖。在你给出的 Acceleration_type union 定义中&#xff0c;包含了三种不同类型的成员…...

MS2201以太网收发电路

MS2201 是吉比特以太网收发器电路&#xff0c;可以实现超高速度的 全双工数据传输。它的通信遵从 IEEE 802.3 Gigabit Ethernet 协议 中的 10 比特接口的时序要求协议。 MS2201 支持数据传输速率从 1Gbps 到 1.85Gbps 。 主要特点 ◼ 电源电压&#xff1a; 2.5V 、 3.3V …...

乐乐音乐Kotlin版

简介 乐乐音乐Kotlin版&#xff0c;主要是基于ExoPlayer框架开发的Android音乐播放器&#xff0c;它支持lrc歌词和动感歌词(ksc歌词、krc歌词、trc歌词、zrce歌词和hrc歌词等)、多种格式歌词转换器及制作动感歌词、翻译歌词和音译歌词。 编译环境 Android Studio Jellyfish | …...

C语言——预处理和指针

C语言——预处理和指针 预处理宏宏定义宏的作用域带参的宏 文件包含条件编译 指针指针的概念指针的定义指针变量初始化指针一维整型数组 预处理 编程的流程分为&#xff1a;编辑、编译、运行、调试四个阶段&#xff1b; 预处理属于编译阶段&#xff0c;编译过程又可以分为&…...

iptables防火墙(一)

目录 1、Linux防火墙基础 2、iptables的四表五链结构 2.1 iptables的四表五链结构介绍 2.2 四表五链 2.2.1 四表 2.2.2 五链 2.3 包过滤的匹配流程 2.3.1 规则链之间匹配顺序 2.3.2 规则链内部的处理规则 2.3.3 数据包过滤的匹配流程 3、 编写防火墙规则 3.1 iptabe…...

(leetcode学习)50. Pow(x, n)

实现 pow(x, n) &#xff0c;即计算 x 的整数 n 次幂函数&#xff08;即&#xff0c;xn &#xff09;。 示例 1&#xff1a; 输入&#xff1a;x 2.00000, n 10 输出&#xff1a;1024.00000示例 2&#xff1a; 输入&#xff1a;x 2.10000, n 3 输出&#xff1a;9.26100示例 …...

QT 5.12.0 for Windows 安装包 QT静态库 采用源码静态编译生成

qt-5.12.0-static.zip 下载地址(资源整理不易&#xff0c;下载使用需付费&#xff0c;且文件较大&#xff0c;不能接受请勿浪费时间下载): 链接&#xff1a;https://pan.baidu.com/s/1ftfHFG_jGFwVaOAvBVrNFg?pwdtvtp 提取码&#xff1a;tvtp...

【生成式人工智能-三-promote 神奇咒语RL增强式学习RAG】

如何激发模型的能力 提示词 promotCoTRL 增强式学习Reforcement learning提供更多的资料提供一些范例Incontext- learning 任务拆解让模型自己检查错误让模型多次生成答案Tree of Thoughts让模型使用其他工具RAG写程序POT其他工具 让多个模型合作参考 在模型不变的情况下&#…...

C++连接oracle数据库连接字符串

//远程连接&#xff0c;需要安装oracle客户端sprintf(szConnect4, ("Provider OraOLEDB.Oracle.1; Password %s; Persist Security Info True; User ID %s; Data Source \"(DESCRIPTION (ADDRESS_LIST (ADDRESS (PROTOCOL TCP)(HOST %s)(PORT 1521)) )(CONN…...

判断字符串是否接近:深入解析及优化【字符串、哈希表、优化过程】

本文将详细解析解决这个问题的思路&#xff0c;并逐步优化实现方案。 问题描述 给定两个字符串 word1 和 word2&#xff0c;如果通过以下操作可以将 word1 转换为 word2&#xff0c;则认为它们是接近的&#xff1a; 交换任意两个现有字符。将一个现有字符的每次出现转换为另…...

C 和 C++ 中信号处理简单介绍

信号处理是编程中一个重要的主题&#xff0c;特别是在需要处理异步事件和错误情况的系统中。在 C 和 C 语言中&#xff0c;信号处理机制提供了一种优雅的方式来响应特定的系统事件&#xff0c;例如用户中断、异常情况或其他信号。在这里&#xff0c;我将详细介绍 C 和 C 中信号…...

什么是云边协同?

当今信息技术高速发展的时代&#xff0c;"云边协同"&#xff08;Edge Cloud Collaboration&#xff09;已经成为一个备受关注的话题。它涉及到云计算和边缘计算的结合&#xff0c;为数据处理、存储和应用提供了全新的可能性。本文将介绍云边协同的概念、优势以及在不…...

YOLOv5改进 | 主干网络 | 将backbone替换为MobileNetV2【小白必备教程+附完整代码】

秋招面试专栏推荐 &#xff1a;深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转 &#x1f4a1;&#x1f4a1;&#x1f4a1;本专栏所有程序均经过测试&#xff0c;可成功执行&#x1f4a1;&#x1f4a1;&#x1f4a1; 专栏目录&#xff1a; 《YOLOv5入门 改…...

ARMxy边缘计算网关用于过程控制子系统

在现代工业生产中&#xff0c;过程控制系统的优化对于提高生产效率、保证产品质量、降低能源消耗等方面都具有重要意义。而 ARMxy 工控机作为一种高性能、高可靠性的工业控制设备&#xff0c;正逐渐成为过程控制系统优化的新选择。 ARMxy 工控机采用了先进的 ARM 架构处理器&am…...

Python | TypeError: unsupported operand type(s) for +=: ‘int’ and ‘str’

Python | TypeError: unsupported operand type(s) for : ‘int’ and ‘str’&#xff1a;深度解析 在Python编程中&#xff0c;遇到“TypeError: unsupported operand type(s) for : ‘int’ and ‘str’”这类错误通常意味着你尝试将一个整数&#xff08;int&#xff09;和…...

什么是开源什么是闭源?以及它们之间的关系

开源软件&#xff08;Open Source Software&#xff09; 定义&#xff1a;开源软件是指其源代码可以被公众访问和使用的软件。用户可以查看、修改和增强软件的源代码。 许可&#xff1a;通常遵循特定的开源许可证&#xff0c;如GNU通用公共许可证&#xff08;GPL&#xff09;、…...

SpringBoot+Mybatis Plus实际开发中的注解

SpringBoot+Mybatis Plus实际开发中的注解 实体类Service层Mapper层Controller层启动类配置类SpringBoot+Mybatis Plus实际开发中的注解 实体类 @Data : 底层实现了getter、setter、toString、hashCode、equals 和无参构造@AllArgsConstructor: 底层实现了有参构造@NoArgsCon…...

地震勘探——干扰波识别、井中地震时距曲线特点

目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波&#xff1a;可以用来解决所提出的地质任务的波&#xff1b;干扰波&#xff1a;所有妨碍辨认、追踪有效波的其他波。 地震勘探中&#xff0c;有效波和干扰波是相对的。例如&#xff0c;在反射波…...

基于FPGA的PID算法学习———实现PID比例控制算法

基于FPGA的PID算法学习 前言一、PID算法分析二、PID仿真分析1. PID代码2.PI代码3.P代码4.顶层5.测试文件6.仿真波形 总结 前言 学习内容&#xff1a;参考网站&#xff1a; PID算法控制 PID即&#xff1a;Proportional&#xff08;比例&#xff09;、Integral&#xff08;积分&…...

线程与协程

1. 线程与协程 1.1. “函数调用级别”的切换、上下文切换 1. 函数调用级别的切换 “函数调用级别的切换”是指&#xff1a;像函数调用/返回一样轻量地完成任务切换。 举例说明&#xff1a; 当你在程序中写一个函数调用&#xff1a; funcA() 然后 funcA 执行完后返回&…...

4. TypeScript 类型推断与类型组合

一、类型推断 (一) 什么是类型推断 TypeScript 的类型推断会根据变量、函数返回值、对象和数组的赋值和使用方式&#xff0c;自动确定它们的类型。 这一特性减少了显式类型注解的需要&#xff0c;在保持类型安全的同时简化了代码。通过分析上下文和初始值&#xff0c;TypeSc…...

Golang——9、反射和文件操作

反射和文件操作 1、反射1.1、reflect.TypeOf()获取任意值的类型对象1.2、reflect.ValueOf()1.3、结构体反射 2、文件操作2.1、os.Open()打开文件2.2、方式一&#xff1a;使用Read()读取文件2.3、方式二&#xff1a;bufio读取文件2.4、方式三&#xff1a;os.ReadFile读取2.5、写…...

Proxmox Mail Gateway安装指南:从零开始配置高效邮件过滤系统

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐&#xff1a;「storms…...

[论文阅读]TrustRAG: Enhancing Robustness and Trustworthiness in RAG

TrustRAG: Enhancing Robustness and Trustworthiness in RAG [2501.00879] TrustRAG: Enhancing Robustness and Trustworthiness in Retrieval-Augmented Generation 代码&#xff1a;HuichiZhou/TrustRAG: Code for "TrustRAG: Enhancing Robustness and Trustworthin…...

从物理机到云原生:全面解析计算虚拟化技术的演进与应用

前言&#xff1a;我的虚拟化技术探索之旅 我最早接触"虚拟机"的概念是从Java开始的——JVM&#xff08;Java Virtual Machine&#xff09;让"一次编写&#xff0c;到处运行"成为可能。这个软件层面的虚拟化让我着迷&#xff0c;但直到后来接触VMware和Doc…...

归并排序:分治思想的高效排序

目录 基本原理 流程图解 实现方法 递归实现 非递归实现 演示过程 时间复杂度 基本原理 归并排序(Merge Sort)是一种基于分治思想的排序算法&#xff0c;由约翰冯诺伊曼在1945年提出。其核心思想包括&#xff1a; 分割(Divide)&#xff1a;将待排序数组递归地分成两个子…...

NineData数据库DevOps功能全面支持百度智能云向量数据库 VectorDB,助力企业 AI 应用高效落地

NineData 的数据库 DevOps 解决方案已完成对百度智能云向量数据库 VectorDB 的全链路适配&#xff0c;成为国内首批提供 VectorDB 原生操作能力的服务商。此次合作聚焦 AI 开发核心场景&#xff0c;通过标准化 SQL 工作台与细粒度权限管控两大能力&#xff0c;助力企业安全高效…...