当前位置: 首页 > news >正文

分类网络搭建示例

搭建CNN网络

本章我们来学习一下如何搭建网络,初始化方法,模型的保存,预训练模型的加载方法。本专栏需要搭建的是对分类性能的测试,所以这里我们只以VGG为例。

请注意,这里定义的只是一个简陋的版本,后续一些经典网络的学习,我们会在另外单独去开一个专栏讲解。

1. 网络搭建

在PyTorch中,你可以使用 torchvision.models 中的 vgg16 来加载预定义的VGG16模型,也可以手动定义。以下是手动定义的一个简化版本:

import torch
import torch.nn as nnclass VGG16(nn.Module):def __init__(self, num_classes=1000):super(VGG16, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),)self.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, num_classes),)def forward(self, x):x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

2. 初始化方法

在这里,我们不再手动初始化每一层,因为PyTorch的默认初始化通常足够好。你可以选择手动初始化,如果需要,可以使用 torch.nn.init 中的不同方法。

3. 模型的保存

使用 torch.save 保存VGG16模型:

vgg16 = VGG16()torch.save(vgg16.state_dict(), 'vgg16_model.pth')

4. 预训练模型的加载

要加载预训练的VGG16模型,你可以使用 torchvision.models 中的 vgg16(pretrained=True),或者手动加载预训练权重:

vgg16 = VGG16()vgg16.load_state_dict(torch.load('pretrained_vgg16.pth'))

请确保路径 'pretrained_vgg16.pth' 是你预训练模型文件的实际路径。你可以从PyTorch的官方模型库或其他来源下载预训练权重。

上面是最简单的一种模型全部加载的方式,但也有一些情况下,只是想加载其中一部分层的参数。剩下一部分由于已经改变参数了,无法加载预训练模型,所以要选择随机初始化。 、

这里我们来观察网络怎么去表示的:

if __name__ == "__main__":model = VGG16()for name, value in model.named_parameters():print(name)

下面就是控制台打印出的部分信息。 

这两行的输出就是打印网络层的名字,实际上加载预训练模型时,也是按照这个名字来加载的。

# 加载预训练 VGG16 模型的参数
pretrained_dict = torch.load('pretrained_vgg16.pth')# 剔除预训练模型中全连接层的参数
pretrained_dict.pop('classifier.0.weight')
pretrained_dict.pop('classifier.0.bias')
pretrained_dict.pop('classifier.3.weight')
pretrained_dict.pop('classifier.3.bias')
pretrained_dict.pop('classifier.6.weight')
pretrained_dict.pop('classifier.6.bias')# 获取自定义模型的参数字典
model_dict = model.state_dict()# 更新自定义模型的参数字典,加载预训练模型的参数值
model_dict.update(pretrained_dict)# 加载更新后的参数字典到自定义模型中
model.load_state_dict(model_dict)

自己定义的一些层是不会出现在pretrained_dict中,因此会将其剔除,从而只加载了 pretrained_dict中有的层。

总结

本章只是对网络的定义进行一个简单的示例,具体的部分我们会在另外一个专栏讲解,这里只是为了让读者了解网络定义的流程。在实际项目中,通常需要更详细的网络结构,包括适当的初始化方法、损失函数的选择、优化器的设置等。如果读者了解掌握了基本的网络定义过程,你可以在本专栏中深入讲解这些方面,以及如何训练和评估模型等内容。

相关文章:

分类网络搭建示例

搭建CNN网络 本章我们来学习一下如何搭建网络,初始化方法,模型的保存,预训练模型的加载方法。本专栏需要搭建的是对分类性能的测试,所以这里我们只以VGG为例。 请注意,这里定义的只是一个简陋的版本,后续一…...

为 Ubuntu 虚拟机构建 SSH 服务器

以校园网环境和VMware为例,关键步骤如下: 安装 SSH 服务: 打开 Ubuntu 虚拟机。打开终端。输入命令 sudo apt-get update 更新软件包列表。输入命令 sudo apt-get install openssh-server 安装 SSH 服务。 配置 SSH 服务: 编辑配…...

SpringBoot--中间件技术-2:整合redis,redis实战小案例,springboot cache,cache简化redis的实现,含代码

SpringBoot整合Redis 实现步骤 导pom文件坐标 <!--redis依赖--> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId> </dependency>yaml主配置文件&#xff0c;配置…...

linux rsyslog配置文件详解

1.rsyslog配置文件简介 linux rsyslog配置文件/etc/rsyslog.conf分为三部分:MODULES、GLOBAL DIRECTIVES、RULES ryslog模块说明 模块说明MODULES指定接收日志的协议和端口。若要配置日志服务器,则需要将相应的配置项注释去掉。GLOBAL DIRECTIVES主要用来配置日志模版。指定…...

wordpress是什么?快速搭网站经验分享

​作者主页 &#x1f4da;lovewold少个r博客主页 ⚠️本文重点&#xff1a;c入门第一个程序和基本知识讲解 &#x1f449;【C-C入门系列专栏】&#xff1a;博客文章专栏传送门 &#x1f604;每日一言&#xff1a;宁静是一片强大而治愈的神奇海洋&#xff01; 目录 前言 wordp…...

排序 算法(第4版)

本博客参考算法&#xff08;第4版&#xff09;&#xff1a;算法&#xff08;第4版&#xff09; - LeetBook - 力扣&#xff08;LeetCode&#xff09;全球极客挚爱的技术成长平台 本文用Java实现相关算法。 我们关注的主要对象是重新排列数组元素的算法&#xff0c;其中每个元素…...

asp.net 在线音乐网站系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net 在线音乐网站系统是一套完善的web设计管理系统&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为vs2010&#xff0c;数据库为sqlserver2008&#xff0c;使用c#语言 开发 asp.net 在线音乐网站系统1 应用…...

ElastaticSearch -- es之Filters aggregation 先过滤再聚合

使用场景 使用es时&#xff0c;有时我们需要先过滤后再聚合&#xff0c;但如果直接在query的filter中过滤&#xff0c;不止会影响到一个聚合&#xff0c;还会影响到其他的聚合结果。 比如&#xff0c;我们想要统计深圳市某个品牌的总销售额&#xff0c;以及该品牌的女款衣服的…...

如何把一个接口设计好?

如何把一个接口设计好&#xff1f; 如何设计一个接口&#xff1f;是在我们日常开发或者面试时经常问及的一个话题。很多人觉得这不就是CRUD&#xff0c;能实现不就行了。单纯实现来说&#xff0c;并非难事&#xff0c;但要做到易用、易扩展、易维护并不是一件简单的事。这里并…...

mini-vue 的设计

mini-vue 的设计 mini-vue 使用流程与结果预览&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><meta name&qu…...

React整理杂记(一)

1.React三项依赖 1.react.js -> 核心代码 2.react-dom.js -> 渲染成dom 3.babel.js->非必须&#xff0c;将jsx转为js 类组件中直接定义的方法&#xff0c;都属于严格模式下 this的绑定可以放到constructor(){}中 2. JSX语法 1.可以直接插入的元素&#xff1a; num…...

[100天算法】-统计封闭岛屿的数目(day 74)

题目描述 有一个二维矩阵 grid &#xff0c;每个位置要么是陆地&#xff08;记号为 0 &#xff09;要么是水域&#xff08;记号为 1 &#xff09;。我们从一块陆地出发&#xff0c;每次可以往上下左右 4 个方向相邻区域走&#xff0c;能走到的所有陆地区域&#xff0c;我们将其…...

esp32-rust-std-examples-blinky

以下为在 ESP-IDF (FreeRTOS) 上运行的 blinky 示例&#xff1a; https://github.com/esp-rs/esp-idf-hal/blob/master/examples/blinky.rs //! Blinks an LED //! //! This assumes that a LED is connected to GPIO4. //! Depending on your target and the board you are …...

【docker容器技术与K8s】

【docker容器技术与K8s】 一、Docker容器技术 1、Docker的学习路线 &#xff08;1&#xff09;学习Docker基本命令&#xff08;容器管理和镜像管理&#xff09; &#xff08;2&#xff09;学习使用Docker搭建常用软件 &#xff08;3&#xff09;学习Docker网络模式 启动容器的…...

RT-DTER 引入用于低分辨率图像和小物体的新 CNN 模块 SPD-Conv

论文地址:https://arxiv.org/pdf/2208.03641v1.pdf 代码地址:https://github.com/labsaint/spd-conv 卷积神经网络(CNN)在图像分类、目标检测等计算机视觉任务中取得了巨大的成功。然而,在图像分辨率较低或对象较小的更困难的任务中,它们的性能会迅速下降。 这源于现有CNN…...

Folw + Room 实现自动观察数据库的刷新

1、Room &#xff1a;定义数据结构、创建数据库 // 定义实体 Entity data class TestModel ()// 定义数据库 Dao interface TestDao { Query("SELECT * FROM TestTable") fun getAll(): List<TestModel> }// 获取数据库 abstract class TestDatabase: RoomDat…...

黑马程序员微服务Docker实用篇

Docker实用篇 0.学习目标 1.初识Docker 1.1.什么是Docker 微服务虽然具备各种各样的优势&#xff0c;但服务的拆分通用给部署带来了很大的麻烦。 分布式系统中&#xff0c;依赖的组件非常多&#xff0c;不同组件之间部署时往往会产生一些冲突。在数百上千台服务中重复部署…...

虚拟化服务器+华为防火墙+kiwi_syslog访问留痕

一、适用场景 1、大中型企业需要对接入用户的访问进行记录时&#xff0c;以前用3CDaemon时&#xff0c;只能用于小型网络当中&#xff0c;记录的数据量太大时&#xff0c;本例采用破解版的kiwi_syslog。 2、当网监、公安查到有非法访问时&#xff0c;可提供基于五元组的外网访…...

FlinkSQL聚合函数(Aggregate Function)详解

使用场景&#xff1a; 聚合函数即 UDAF&#xff0c;常⽤于进多条数据&#xff0c;出⼀条数据的场景。 上图展示了⼀个 聚合函数的例⼦ 以及 聚合函数包含的重要⽅法。 案例场景&#xff1a; 关于饮料的表&#xff0c;有三个字段&#xff0c;分别是 id、name、price&#xff0…...

TensorFlow学习笔记--(3)张量的常用运算函数

损失函数及求偏导 通过 tf.GradientTape 函数来指定损失函数的变量以及表达式 最后通过 gradient(%损失函数%,%偏导对象%) 来获取求偏导的结果 独热编码 给出一组特征值 来对图像进行分类 可以用独热编码 0的概率是第0种 1的概率是第1种 0的概率是第二种 tf.one_hot(%某标签…...

idea大量爆红问题解决

问题描述 在学习和工作中&#xff0c;idea是程序员不可缺少的一个工具&#xff0c;但是突然在有些时候就会出现大量爆红的问题&#xff0c;发现无法跳转&#xff0c;无论是关机重启或者是替换root都无法解决 就是如上所展示的问题&#xff0c;但是程序依然可以启动。 问题解决…...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建&#xff08;全平台详解&#xff09; 在开始使用 React Native 开发移动应用之前&#xff0c;正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南&#xff0c;涵盖 macOS 和 Windows 平台的配置步骤&#xff0c;如何在 Android 和 iOS…...

《Playwright:微软的自动化测试工具详解》

Playwright 简介:声明内容来自网络&#xff0c;将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具&#xff0c;支持 Chrome、Firefox、Safari 等主流浏览器&#xff0c;提供多语言 API&#xff08;Python、JavaScript、Java、.NET&#xff09;。它的特点包括&a…...

Go 语言接口详解

Go 语言接口详解 核心概念 接口定义 在 Go 语言中&#xff0c;接口是一种抽象类型&#xff0c;它定义了一组方法的集合&#xff1a; // 定义接口 type Shape interface {Area() float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的&#xff1a; // 矩形结构体…...

江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命

在华东塑料包装行业面临限塑令深度调整的背景下&#xff0c;江苏艾立泰以一场跨国资源接力的创新实践&#xff0c;重新定义了绿色供应链的边界。 跨国回收网络&#xff1a;废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点&#xff0c;将海外废弃包装箱通过标准…...

高等数学(下)题型笔记(八)空间解析几何与向量代数

目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...

【2025年】解决Burpsuite抓不到https包的问题

环境&#xff1a;windows11 burpsuite:2025.5 在抓取https网站时&#xff0c;burpsuite抓取不到https数据包&#xff0c;只显示&#xff1a; 解决该问题只需如下三个步骤&#xff1a; 1、浏览器中访问 http://burp 2、下载 CA certificate 证书 3、在设置--隐私与安全--…...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角&#xff0c;以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向&#xff0c;距离坐标原点x个像素;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐标原点y个像素。 坐标体系-像素 …...

MySQL用户和授权

开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务&#xff1a; test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...

基于matlab策略迭代和值迭代法的动态规划

经典的基于策略迭代和值迭代法的动态规划matlab代码&#xff0c;实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...