当前位置: 首页 > news >正文

分类网络搭建示例

搭建CNN网络

本章我们来学习一下如何搭建网络,初始化方法,模型的保存,预训练模型的加载方法。本专栏需要搭建的是对分类性能的测试,所以这里我们只以VGG为例。

请注意,这里定义的只是一个简陋的版本,后续一些经典网络的学习,我们会在另外单独去开一个专栏讲解。

1. 网络搭建

在PyTorch中,你可以使用 torchvision.models 中的 vgg16 来加载预定义的VGG16模型,也可以手动定义。以下是手动定义的一个简化版本:

import torch
import torch.nn as nnclass VGG16(nn.Module):def __init__(self, num_classes=1000):super(VGG16, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),)self.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, num_classes),)def forward(self, x):x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

2. 初始化方法

在这里,我们不再手动初始化每一层,因为PyTorch的默认初始化通常足够好。你可以选择手动初始化,如果需要,可以使用 torch.nn.init 中的不同方法。

3. 模型的保存

使用 torch.save 保存VGG16模型:

vgg16 = VGG16()torch.save(vgg16.state_dict(), 'vgg16_model.pth')

4. 预训练模型的加载

要加载预训练的VGG16模型,你可以使用 torchvision.models 中的 vgg16(pretrained=True),或者手动加载预训练权重:

vgg16 = VGG16()vgg16.load_state_dict(torch.load('pretrained_vgg16.pth'))

请确保路径 'pretrained_vgg16.pth' 是你预训练模型文件的实际路径。你可以从PyTorch的官方模型库或其他来源下载预训练权重。

上面是最简单的一种模型全部加载的方式,但也有一些情况下,只是想加载其中一部分层的参数。剩下一部分由于已经改变参数了,无法加载预训练模型,所以要选择随机初始化。 、

这里我们来观察网络怎么去表示的:

if __name__ == "__main__":model = VGG16()for name, value in model.named_parameters():print(name)

下面就是控制台打印出的部分信息。 

这两行的输出就是打印网络层的名字,实际上加载预训练模型时,也是按照这个名字来加载的。

# 加载预训练 VGG16 模型的参数
pretrained_dict = torch.load('pretrained_vgg16.pth')# 剔除预训练模型中全连接层的参数
pretrained_dict.pop('classifier.0.weight')
pretrained_dict.pop('classifier.0.bias')
pretrained_dict.pop('classifier.3.weight')
pretrained_dict.pop('classifier.3.bias')
pretrained_dict.pop('classifier.6.weight')
pretrained_dict.pop('classifier.6.bias')# 获取自定义模型的参数字典
model_dict = model.state_dict()# 更新自定义模型的参数字典,加载预训练模型的参数值
model_dict.update(pretrained_dict)# 加载更新后的参数字典到自定义模型中
model.load_state_dict(model_dict)

自己定义的一些层是不会出现在pretrained_dict中,因此会将其剔除,从而只加载了 pretrained_dict中有的层。

总结

本章只是对网络的定义进行一个简单的示例,具体的部分我们会在另外一个专栏讲解,这里只是为了让读者了解网络定义的流程。在实际项目中,通常需要更详细的网络结构,包括适当的初始化方法、损失函数的选择、优化器的设置等。如果读者了解掌握了基本的网络定义过程,你可以在本专栏中深入讲解这些方面,以及如何训练和评估模型等内容。

相关文章:

分类网络搭建示例

搭建CNN网络 本章我们来学习一下如何搭建网络,初始化方法,模型的保存,预训练模型的加载方法。本专栏需要搭建的是对分类性能的测试,所以这里我们只以VGG为例。 请注意,这里定义的只是一个简陋的版本,后续一…...

为 Ubuntu 虚拟机构建 SSH 服务器

以校园网环境和VMware为例,关键步骤如下: 安装 SSH 服务: 打开 Ubuntu 虚拟机。打开终端。输入命令 sudo apt-get update 更新软件包列表。输入命令 sudo apt-get install openssh-server 安装 SSH 服务。 配置 SSH 服务: 编辑配…...

SpringBoot--中间件技术-2:整合redis,redis实战小案例,springboot cache,cache简化redis的实现,含代码

SpringBoot整合Redis 实现步骤 导pom文件坐标 <!--redis依赖--> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId> </dependency>yaml主配置文件&#xff0c;配置…...

linux rsyslog配置文件详解

1.rsyslog配置文件简介 linux rsyslog配置文件/etc/rsyslog.conf分为三部分:MODULES、GLOBAL DIRECTIVES、RULES ryslog模块说明 模块说明MODULES指定接收日志的协议和端口。若要配置日志服务器,则需要将相应的配置项注释去掉。GLOBAL DIRECTIVES主要用来配置日志模版。指定…...

wordpress是什么?快速搭网站经验分享

​作者主页 &#x1f4da;lovewold少个r博客主页 ⚠️本文重点&#xff1a;c入门第一个程序和基本知识讲解 &#x1f449;【C-C入门系列专栏】&#xff1a;博客文章专栏传送门 &#x1f604;每日一言&#xff1a;宁静是一片强大而治愈的神奇海洋&#xff01; 目录 前言 wordp…...

排序 算法(第4版)

本博客参考算法&#xff08;第4版&#xff09;&#xff1a;算法&#xff08;第4版&#xff09; - LeetBook - 力扣&#xff08;LeetCode&#xff09;全球极客挚爱的技术成长平台 本文用Java实现相关算法。 我们关注的主要对象是重新排列数组元素的算法&#xff0c;其中每个元素…...

asp.net 在线音乐网站系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net 在线音乐网站系统是一套完善的web设计管理系统&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为vs2010&#xff0c;数据库为sqlserver2008&#xff0c;使用c#语言 开发 asp.net 在线音乐网站系统1 应用…...

ElastaticSearch -- es之Filters aggregation 先过滤再聚合

使用场景 使用es时&#xff0c;有时我们需要先过滤后再聚合&#xff0c;但如果直接在query的filter中过滤&#xff0c;不止会影响到一个聚合&#xff0c;还会影响到其他的聚合结果。 比如&#xff0c;我们想要统计深圳市某个品牌的总销售额&#xff0c;以及该品牌的女款衣服的…...

如何把一个接口设计好?

如何把一个接口设计好&#xff1f; 如何设计一个接口&#xff1f;是在我们日常开发或者面试时经常问及的一个话题。很多人觉得这不就是CRUD&#xff0c;能实现不就行了。单纯实现来说&#xff0c;并非难事&#xff0c;但要做到易用、易扩展、易维护并不是一件简单的事。这里并…...

mini-vue 的设计

mini-vue 的设计 mini-vue 使用流程与结果预览&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><meta name&qu…...

React整理杂记(一)

1.React三项依赖 1.react.js -> 核心代码 2.react-dom.js -> 渲染成dom 3.babel.js->非必须&#xff0c;将jsx转为js 类组件中直接定义的方法&#xff0c;都属于严格模式下 this的绑定可以放到constructor(){}中 2. JSX语法 1.可以直接插入的元素&#xff1a; num…...

[100天算法】-统计封闭岛屿的数目(day 74)

题目描述 有一个二维矩阵 grid &#xff0c;每个位置要么是陆地&#xff08;记号为 0 &#xff09;要么是水域&#xff08;记号为 1 &#xff09;。我们从一块陆地出发&#xff0c;每次可以往上下左右 4 个方向相邻区域走&#xff0c;能走到的所有陆地区域&#xff0c;我们将其…...

esp32-rust-std-examples-blinky

以下为在 ESP-IDF (FreeRTOS) 上运行的 blinky 示例&#xff1a; https://github.com/esp-rs/esp-idf-hal/blob/master/examples/blinky.rs //! Blinks an LED //! //! This assumes that a LED is connected to GPIO4. //! Depending on your target and the board you are …...

【docker容器技术与K8s】

【docker容器技术与K8s】 一、Docker容器技术 1、Docker的学习路线 &#xff08;1&#xff09;学习Docker基本命令&#xff08;容器管理和镜像管理&#xff09; &#xff08;2&#xff09;学习使用Docker搭建常用软件 &#xff08;3&#xff09;学习Docker网络模式 启动容器的…...

RT-DTER 引入用于低分辨率图像和小物体的新 CNN 模块 SPD-Conv

论文地址:https://arxiv.org/pdf/2208.03641v1.pdf 代码地址:https://github.com/labsaint/spd-conv 卷积神经网络(CNN)在图像分类、目标检测等计算机视觉任务中取得了巨大的成功。然而,在图像分辨率较低或对象较小的更困难的任务中,它们的性能会迅速下降。 这源于现有CNN…...

Folw + Room 实现自动观察数据库的刷新

1、Room &#xff1a;定义数据结构、创建数据库 // 定义实体 Entity data class TestModel ()// 定义数据库 Dao interface TestDao { Query("SELECT * FROM TestTable") fun getAll(): List<TestModel> }// 获取数据库 abstract class TestDatabase: RoomDat…...

黑马程序员微服务Docker实用篇

Docker实用篇 0.学习目标 1.初识Docker 1.1.什么是Docker 微服务虽然具备各种各样的优势&#xff0c;但服务的拆分通用给部署带来了很大的麻烦。 分布式系统中&#xff0c;依赖的组件非常多&#xff0c;不同组件之间部署时往往会产生一些冲突。在数百上千台服务中重复部署…...

虚拟化服务器+华为防火墙+kiwi_syslog访问留痕

一、适用场景 1、大中型企业需要对接入用户的访问进行记录时&#xff0c;以前用3CDaemon时&#xff0c;只能用于小型网络当中&#xff0c;记录的数据量太大时&#xff0c;本例采用破解版的kiwi_syslog。 2、当网监、公安查到有非法访问时&#xff0c;可提供基于五元组的外网访…...

FlinkSQL聚合函数(Aggregate Function)详解

使用场景&#xff1a; 聚合函数即 UDAF&#xff0c;常⽤于进多条数据&#xff0c;出⼀条数据的场景。 上图展示了⼀个 聚合函数的例⼦ 以及 聚合函数包含的重要⽅法。 案例场景&#xff1a; 关于饮料的表&#xff0c;有三个字段&#xff0c;分别是 id、name、price&#xff0…...

TensorFlow学习笔记--(3)张量的常用运算函数

损失函数及求偏导 通过 tf.GradientTape 函数来指定损失函数的变量以及表达式 最后通过 gradient(%损失函数%,%偏导对象%) 来获取求偏导的结果 独热编码 给出一组特征值 来对图像进行分类 可以用独热编码 0的概率是第0种 1的概率是第1种 0的概率是第二种 tf.one_hot(%某标签…...

生成xcframework

打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式&#xff0c;可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...

Appium+python自动化(十六)- ADB命令

简介 Android 调试桥(adb)是多种用途的工具&#xff0c;该工具可以帮助你你管理设备或模拟器 的状态。 adb ( Android Debug Bridge)是一个通用命令行工具&#xff0c;其允许您与模拟器实例或连接的 Android 设备进行通信。它可为各种设备操作提供便利&#xff0c;如安装和调试…...

【位运算】消失的两个数字(hard)

消失的两个数字&#xff08;hard&#xff09; 题⽬描述&#xff1a;解法&#xff08;位运算&#xff09;&#xff1a;Java 算法代码&#xff1a;更简便代码 题⽬链接&#xff1a;⾯试题 17.19. 消失的两个数字 题⽬描述&#xff1a; 给定⼀个数组&#xff0c;包含从 1 到 N 所有…...

为什么需要建设工程项目管理?工程项目管理有哪些亮点功能?

在建筑行业&#xff0c;项目管理的重要性不言而喻。随着工程规模的扩大、技术复杂度的提升&#xff0c;传统的管理模式已经难以满足现代工程的需求。过去&#xff0c;许多企业依赖手工记录、口头沟通和分散的信息管理&#xff0c;导致效率低下、成本失控、风险频发。例如&#…...

高等数学(下)题型笔记(八)空间解析几何与向量代数

目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...

Ascend NPU上适配Step-Audio模型

1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统&#xff0c;支持多语言对话&#xff08;如 中文&#xff0c;英文&#xff0c;日语&#xff09;&#xff0c;语音情感&#xff08;如 开心&#xff0c;悲伤&#xff09;&#x…...

大模型多显卡多服务器并行计算方法与实践指南

一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...

网站指纹识别

网站指纹识别 网站的最基本组成&#xff1a;服务器&#xff08;操作系统&#xff09;、中间件&#xff08;web容器&#xff09;、脚本语言、数据厍 为什么要了解这些&#xff1f;举个例子&#xff1a;发现了一个文件读取漏洞&#xff0c;我们需要读/etc/passwd&#xff0c;如…...

08. C#入门系列【类的基本概念】:开启编程世界的奇妙冒险

C#入门系列【类的基本概念】&#xff1a;开启编程世界的奇妙冒险 嘿&#xff0c;各位编程小白探险家&#xff01;欢迎来到 C# 的奇幻大陆&#xff01;今天咱们要深入探索这片大陆上至关重要的 “建筑”—— 类&#xff01;别害怕&#xff0c;跟着我&#xff0c;保准让你轻松搞…...

R 语言科研绘图第 55 期 --- 网络图-聚类

在发表科研论文的过程中&#xff0c;科研绘图是必不可少的&#xff0c;一张好看的图形会是文章很大的加分项。 为了便于使用&#xff0c;本系列文章介绍的所有绘图都已收录到了 sciRplot 项目中&#xff0c;获取方式&#xff1a; R 语言科研绘图模板 --- sciRplothttps://mp.…...