当前位置: 首页 > news >正文

分类网络搭建示例

搭建CNN网络

本章我们来学习一下如何搭建网络,初始化方法,模型的保存,预训练模型的加载方法。本专栏需要搭建的是对分类性能的测试,所以这里我们只以VGG为例。

请注意,这里定义的只是一个简陋的版本,后续一些经典网络的学习,我们会在另外单独去开一个专栏讲解。

1. 网络搭建

在PyTorch中,你可以使用 torchvision.models 中的 vgg16 来加载预定义的VGG16模型,也可以手动定义。以下是手动定义的一个简化版本:

import torch
import torch.nn as nnclass VGG16(nn.Module):def __init__(self, num_classes=1000):super(VGG16, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),)self.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, num_classes),)def forward(self, x):x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

2. 初始化方法

在这里,我们不再手动初始化每一层,因为PyTorch的默认初始化通常足够好。你可以选择手动初始化,如果需要,可以使用 torch.nn.init 中的不同方法。

3. 模型的保存

使用 torch.save 保存VGG16模型:

vgg16 = VGG16()torch.save(vgg16.state_dict(), 'vgg16_model.pth')

4. 预训练模型的加载

要加载预训练的VGG16模型,你可以使用 torchvision.models 中的 vgg16(pretrained=True),或者手动加载预训练权重:

vgg16 = VGG16()vgg16.load_state_dict(torch.load('pretrained_vgg16.pth'))

请确保路径 'pretrained_vgg16.pth' 是你预训练模型文件的实际路径。你可以从PyTorch的官方模型库或其他来源下载预训练权重。

上面是最简单的一种模型全部加载的方式,但也有一些情况下,只是想加载其中一部分层的参数。剩下一部分由于已经改变参数了,无法加载预训练模型,所以要选择随机初始化。 、

这里我们来观察网络怎么去表示的:

if __name__ == "__main__":model = VGG16()for name, value in model.named_parameters():print(name)

下面就是控制台打印出的部分信息。 

这两行的输出就是打印网络层的名字,实际上加载预训练模型时,也是按照这个名字来加载的。

# 加载预训练 VGG16 模型的参数
pretrained_dict = torch.load('pretrained_vgg16.pth')# 剔除预训练模型中全连接层的参数
pretrained_dict.pop('classifier.0.weight')
pretrained_dict.pop('classifier.0.bias')
pretrained_dict.pop('classifier.3.weight')
pretrained_dict.pop('classifier.3.bias')
pretrained_dict.pop('classifier.6.weight')
pretrained_dict.pop('classifier.6.bias')# 获取自定义模型的参数字典
model_dict = model.state_dict()# 更新自定义模型的参数字典,加载预训练模型的参数值
model_dict.update(pretrained_dict)# 加载更新后的参数字典到自定义模型中
model.load_state_dict(model_dict)

自己定义的一些层是不会出现在pretrained_dict中,因此会将其剔除,从而只加载了 pretrained_dict中有的层。

总结

本章只是对网络的定义进行一个简单的示例,具体的部分我们会在另外一个专栏讲解,这里只是为了让读者了解网络定义的流程。在实际项目中,通常需要更详细的网络结构,包括适当的初始化方法、损失函数的选择、优化器的设置等。如果读者了解掌握了基本的网络定义过程,你可以在本专栏中深入讲解这些方面,以及如何训练和评估模型等内容。

相关文章:

分类网络搭建示例

搭建CNN网络 本章我们来学习一下如何搭建网络,初始化方法,模型的保存,预训练模型的加载方法。本专栏需要搭建的是对分类性能的测试,所以这里我们只以VGG为例。 请注意,这里定义的只是一个简陋的版本,后续一…...

为 Ubuntu 虚拟机构建 SSH 服务器

以校园网环境和VMware为例,关键步骤如下: 安装 SSH 服务: 打开 Ubuntu 虚拟机。打开终端。输入命令 sudo apt-get update 更新软件包列表。输入命令 sudo apt-get install openssh-server 安装 SSH 服务。 配置 SSH 服务: 编辑配…...

SpringBoot--中间件技术-2:整合redis,redis实战小案例,springboot cache,cache简化redis的实现,含代码

SpringBoot整合Redis 实现步骤 导pom文件坐标 <!--redis依赖--> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId> </dependency>yaml主配置文件&#xff0c;配置…...

linux rsyslog配置文件详解

1.rsyslog配置文件简介 linux rsyslog配置文件/etc/rsyslog.conf分为三部分:MODULES、GLOBAL DIRECTIVES、RULES ryslog模块说明 模块说明MODULES指定接收日志的协议和端口。若要配置日志服务器,则需要将相应的配置项注释去掉。GLOBAL DIRECTIVES主要用来配置日志模版。指定…...

wordpress是什么?快速搭网站经验分享

​作者主页 &#x1f4da;lovewold少个r博客主页 ⚠️本文重点&#xff1a;c入门第一个程序和基本知识讲解 &#x1f449;【C-C入门系列专栏】&#xff1a;博客文章专栏传送门 &#x1f604;每日一言&#xff1a;宁静是一片强大而治愈的神奇海洋&#xff01; 目录 前言 wordp…...

排序 算法(第4版)

本博客参考算法&#xff08;第4版&#xff09;&#xff1a;算法&#xff08;第4版&#xff09; - LeetBook - 力扣&#xff08;LeetCode&#xff09;全球极客挚爱的技术成长平台 本文用Java实现相关算法。 我们关注的主要对象是重新排列数组元素的算法&#xff0c;其中每个元素…...

asp.net 在线音乐网站系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net 在线音乐网站系统是一套完善的web设计管理系统&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为vs2010&#xff0c;数据库为sqlserver2008&#xff0c;使用c#语言 开发 asp.net 在线音乐网站系统1 应用…...

ElastaticSearch -- es之Filters aggregation 先过滤再聚合

使用场景 使用es时&#xff0c;有时我们需要先过滤后再聚合&#xff0c;但如果直接在query的filter中过滤&#xff0c;不止会影响到一个聚合&#xff0c;还会影响到其他的聚合结果。 比如&#xff0c;我们想要统计深圳市某个品牌的总销售额&#xff0c;以及该品牌的女款衣服的…...

如何把一个接口设计好?

如何把一个接口设计好&#xff1f; 如何设计一个接口&#xff1f;是在我们日常开发或者面试时经常问及的一个话题。很多人觉得这不就是CRUD&#xff0c;能实现不就行了。单纯实现来说&#xff0c;并非难事&#xff0c;但要做到易用、易扩展、易维护并不是一件简单的事。这里并…...

mini-vue 的设计

mini-vue 的设计 mini-vue 使用流程与结果预览&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><meta name&qu…...

React整理杂记(一)

1.React三项依赖 1.react.js -> 核心代码 2.react-dom.js -> 渲染成dom 3.babel.js->非必须&#xff0c;将jsx转为js 类组件中直接定义的方法&#xff0c;都属于严格模式下 this的绑定可以放到constructor(){}中 2. JSX语法 1.可以直接插入的元素&#xff1a; num…...

[100天算法】-统计封闭岛屿的数目(day 74)

题目描述 有一个二维矩阵 grid &#xff0c;每个位置要么是陆地&#xff08;记号为 0 &#xff09;要么是水域&#xff08;记号为 1 &#xff09;。我们从一块陆地出发&#xff0c;每次可以往上下左右 4 个方向相邻区域走&#xff0c;能走到的所有陆地区域&#xff0c;我们将其…...

esp32-rust-std-examples-blinky

以下为在 ESP-IDF (FreeRTOS) 上运行的 blinky 示例&#xff1a; https://github.com/esp-rs/esp-idf-hal/blob/master/examples/blinky.rs //! Blinks an LED //! //! This assumes that a LED is connected to GPIO4. //! Depending on your target and the board you are …...

【docker容器技术与K8s】

【docker容器技术与K8s】 一、Docker容器技术 1、Docker的学习路线 &#xff08;1&#xff09;学习Docker基本命令&#xff08;容器管理和镜像管理&#xff09; &#xff08;2&#xff09;学习使用Docker搭建常用软件 &#xff08;3&#xff09;学习Docker网络模式 启动容器的…...

RT-DTER 引入用于低分辨率图像和小物体的新 CNN 模块 SPD-Conv

论文地址:https://arxiv.org/pdf/2208.03641v1.pdf 代码地址:https://github.com/labsaint/spd-conv 卷积神经网络(CNN)在图像分类、目标检测等计算机视觉任务中取得了巨大的成功。然而,在图像分辨率较低或对象较小的更困难的任务中,它们的性能会迅速下降。 这源于现有CNN…...

Folw + Room 实现自动观察数据库的刷新

1、Room &#xff1a;定义数据结构、创建数据库 // 定义实体 Entity data class TestModel ()// 定义数据库 Dao interface TestDao { Query("SELECT * FROM TestTable") fun getAll(): List<TestModel> }// 获取数据库 abstract class TestDatabase: RoomDat…...

黑马程序员微服务Docker实用篇

Docker实用篇 0.学习目标 1.初识Docker 1.1.什么是Docker 微服务虽然具备各种各样的优势&#xff0c;但服务的拆分通用给部署带来了很大的麻烦。 分布式系统中&#xff0c;依赖的组件非常多&#xff0c;不同组件之间部署时往往会产生一些冲突。在数百上千台服务中重复部署…...

虚拟化服务器+华为防火墙+kiwi_syslog访问留痕

一、适用场景 1、大中型企业需要对接入用户的访问进行记录时&#xff0c;以前用3CDaemon时&#xff0c;只能用于小型网络当中&#xff0c;记录的数据量太大时&#xff0c;本例采用破解版的kiwi_syslog。 2、当网监、公安查到有非法访问时&#xff0c;可提供基于五元组的外网访…...

FlinkSQL聚合函数(Aggregate Function)详解

使用场景&#xff1a; 聚合函数即 UDAF&#xff0c;常⽤于进多条数据&#xff0c;出⼀条数据的场景。 上图展示了⼀个 聚合函数的例⼦ 以及 聚合函数包含的重要⽅法。 案例场景&#xff1a; 关于饮料的表&#xff0c;有三个字段&#xff0c;分别是 id、name、price&#xff0…...

TensorFlow学习笔记--(3)张量的常用运算函数

损失函数及求偏导 通过 tf.GradientTape 函数来指定损失函数的变量以及表达式 最后通过 gradient(%损失函数%,%偏导对象%) 来获取求偏导的结果 独热编码 给出一组特征值 来对图像进行分类 可以用独热编码 0的概率是第0种 1的概率是第1种 0的概率是第二种 tf.one_hot(%某标签…...

Python|GIF 解析与构建(5):手搓截屏和帧率控制

目录 Python&#xff5c;GIF 解析与构建&#xff08;5&#xff09;&#xff1a;手搓截屏和帧率控制 一、引言 二、技术实现&#xff1a;手搓截屏模块 2.1 核心原理 2.2 代码解析&#xff1a;ScreenshotData类 2.2.1 截图函数&#xff1a;capture_screen 三、技术实现&…...

C++_核心编程_多态案例二-制作饮品

#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为&#xff1a;煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例&#xff0c;提供抽象制作饮品基类&#xff0c;提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...

LeetCode - 394. 字符串解码

题目 394. 字符串解码 - 力扣&#xff08;LeetCode&#xff09; 思路 使用两个栈&#xff1a;一个存储重复次数&#xff0c;一个存储字符串 遍历输入字符串&#xff1a; 数字处理&#xff1a;遇到数字时&#xff0c;累积计算重复次数左括号处理&#xff1a;保存当前状态&a…...

免费PDF转图片工具

免费PDF转图片工具 一款简单易用的PDF转图片工具&#xff0c;可以将PDF文件快速转换为高质量PNG图片。无需安装复杂的软件&#xff0c;也不需要在线上传文件&#xff0c;保护您的隐私。 工具截图 主要特点 &#x1f680; 快速转换&#xff1a;本地转换&#xff0c;无需等待上…...

【网络安全】开源系统getshell漏洞挖掘

审计过程&#xff1a; 在入口文件admin/index.php中&#xff1a; 用户可以通过m,c,a等参数控制加载的文件和方法&#xff0c;在app/system/entrance.php中存在重点代码&#xff1a; 当M_TYPE system并且M_MODULE include时&#xff0c;会设置常量PATH_OWN_FILE为PATH_APP.M_T…...

day36-多路IO复用

一、基本概念 &#xff08;服务器多客户端模型&#xff09; 定义&#xff1a;单线程或单进程同时监测若干个文件描述符是否可以执行IO操作的能力 作用&#xff1a;应用程序通常需要处理来自多条事件流中的事件&#xff0c;比如我现在用的电脑&#xff0c;需要同时处理键盘鼠标…...

TJCTF 2025

还以为是天津的。这个比较容易&#xff0c;虽然绕了点弯&#xff0c;可还是把CP AK了&#xff0c;不过我会的别人也会&#xff0c;还是没啥名次。记录一下吧。 Crypto bacon-bits with open(flag.txt) as f: flag f.read().strip() with open(text.txt) as t: text t.read…...

在Zenodo下载文件 用到googlecolab googledrive

方法&#xff1a;Figshare/Zenodo上的数据/文件下载不下来&#xff1f;尝试利用Google Colab &#xff1a;https://zhuanlan.zhihu.com/p/1898503078782674027 参考&#xff1a; 通过Colab&谷歌云下载Figshare数据&#xff0c;超级实用&#xff01;&#xff01;&#xff0…...

CppCon 2015 学习:REFLECTION TECHNIQUES IN C++

关于 Reflection&#xff08;反射&#xff09; 这个概念&#xff0c;总结一下&#xff1a; Reflection&#xff08;反射&#xff09;是什么&#xff1f; 反射是对类型的自我检查能力&#xff08;Introspection&#xff09; 可以查看类的成员变量、成员函数等信息。反射允许枚…...

性能优化中,多面体模型基本原理

1&#xff09;多面体编译技术是一种基于多面体模型的程序分析和优化技术&#xff0c;它将程序 中的语句实例、访问关系、依赖关系和调度等信息映射到多维空间中的几何对 象&#xff0c;通过对这些几何对象进行几何操作和线性代数计算来进行程序的分析和优 化。 其中&#xff0…...