优化器与现有网络模型的修改
文章目录
- 一、优化器是什么
- 二、优化器的使用
- 三、分类模型VGG16
- 四、现有网络模型的修改
一、优化器是什么
优化器(Optimizer)是一个算法,用于在训练过程中调整模型的参数,以便最小化损失函数(Loss Function)。损失函数衡量的是模型预测值与真实值之间的差异,而优化器则负责通过更新模型的权重(Weights)和偏置(Biases)来减少这种差异。
利用得到的梯度,用优化器对梯度进行修正,从而得到整体误差降低的目的。
优化器Optimizer 所需要从参数:

参数解析:
- model.parameters()是训练的模型
- lr(LearningRate)是学习率,这是最核心的参数之一,它决定了在每次迭代中参数更新的步长。如果学习率太高,可能会导致训练过程中的梯度爆炸,使模型无法收敛,训练很不稳定;如果学习率太低,训练过程可能会变得非常缓慢。
推荐一开始用大的lr值进行运算,到后面用小的lr再进行运算。 - 动量(Momentum)往往是特定参数,是用于加速梯度下降方法,特别是在处理凸优化问题时。它通过在连续的迭代中累积梯度信息来帮助优化器克服局部最小值,并加快收敛速度。
二、优化器的使用
本文使用我的上一章内容神经网络内容进行续写,神经网络具体可跳转损失函数和反向传播
使用一下代码来进行梯度优化:
optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优optim.step()
整体代码如下:
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download= True)dataloader = DataLoader (dataset, batch_size = 1)
class Sen(nn.Module):def __init__(self):super(Sen,self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, 1, 2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = nn.CrossEntropyLoss()
sen = Sen()#随机梯度下降
optim = torch.optim.SGD(sen.parameters(), lr=0.01)for data in dataloader:imgs, tatgets = dataoutputs = sen(imgs)result_loss = loss(outputs, tatgets)#对参数进行梯度清零optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优optim.step()
在未运行时的梯度没有值:

当运行一下:

可以看到每个参数节点的值被计算出来了。
当for循环第二次运行的时候,可以看到grad梯度已经被优化了:

通过反复循环,上图中的data数据,也就是loss就会越来越被优化。
上面的for循环其实是为数据的一次小循环,我们可以加上epoch 外嵌套 进行数据的一轮轮循环深度优化:
for epoch in range(20):running_loss = 0.0#这里只是进行了一次的学习for data in dataloader:imgs, tatgets = dataoutputs = sen(imgs)result_loss = loss(outputs, tatgets)#对参数进行梯度清零optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优aoptim.step()#这一步就相当于所有误差的一个整体求和running_loss = running_loss + result_loss
整体代码:
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download= True)dataloader = DataLoader (dataset, batch_size = 1)
class Sen(nn.Module):def __init__(self):super(Sen,self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, 1, 2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = nn.CrossEntropyLoss()
sen = Sen()#随机梯度下降
optim = torch.optim.SGD(sen.parameters(), lr=0.01)#这里是进行一轮一轮的学习
for epoch in range(20):running_loss = 0.0#这里只是进行了一次的学习for data in dataloader:imgs, tatgets = dataoutputs = sen(imgs)result_loss = loss(outputs, tatgets)#对参数进行梯度清零optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优aoptim.step()#这一步就相当于所有误差的一个整体求和running_loss = running_loss + result_lossprint(running_loss)
运行结果如下,可以看到,整个神经网络在所有的数据当中,它的误差之和如下:

在第一轮优化的时候,整个神经网络的误差之和是18779
在第二轮优化的时候,整个神经网络的误差之和是16205
在第三轮优化的时候,整个神经网络的误差之和是15448
可以看到,通过优化器的一轮轮优化,整体的loss值会一直降低,从而达到数据优化的效果。
三、分类模型VGG16
pytorch为我们提供了很多网络模型,其中包括分类模型VGG16
分类模型VGG16是基于ImageNet数据集进行训练的,所以我们需要下载ImageNet数据集
由于ImageNet数据集的内存为143g,会发生以下报错,需要我们自己去下载ImageNet数据集再放在根目录当中。

既然ImageNet数据集太大,那么就换一条思路,用一下方法加载vgg16
import torchvision.datasets
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_True = torchvision.models.vgg16(pretrained=True)
print('ok')
如果pretrained = True,说明这个数据集已经是训练好的了。
如果pretrained = False,说明这些参数是一个初始参数,没有在任何参数集上面进行训练。
如果progress = True,显示下载进度条
如果progress = Flase,则不显示下载进度条
vgg16_false = torchvision.models.vgg16(pretrained=False),这代码表示只是加载网络模型(也就是像之前的网络模型那样,只是加载模型,含有卷积,池化等,其中的参数都是默认的),所以它不需要下载。
vgg16_True = torchvision.models.vgg16(pretrained=True),这代码表示需要把网络模型参数进行一个下载,还要加载对应的参数。故它需要进行下载。
简单理解就是False不需要进行下载,而True需要进行下载。
VGG16将数据集分成1000个类。
print(vgg16_true)
输出结果:


看它把各种卷积层,最大池化都自动按参数下载好了。
常用的CIFAR10会把数据集分成10个类。
vgg16会把数据集分成1000个类,如上图的out_features=1000
四、现有网络模型的修改
方法:像上面得到的是out_features=1000,我们可以进行一个新的处理,通过Linear将输入是1000,而输出为10,从而达到降类的效果。
vgg16_true.add_module("add_linear", nn.Linear(1000, 10))
运行得到:

可以看到,在add_linear这里的out_features=10
如果要想类的改变在classifier当中,那么代码只需要添加上classifier
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
运行结果:

整体代码如下:
import torchvision.datasets
from torch import nnvgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)print(vgg16_true)train_data = torchvision.datasets.CIFAR10("./data",train=True, transform=torchvision.transforms.ToTensor(),download=True)vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
如果想直接在上面 (6)Linear 里面修改out_features,而不是新命名一个(add_linear)进行修改也是可以的
用vgg16_flase进行示范:
在没进行修改前print(vgg16_false)
运行结果:

直接在(6)Linear中修改out_features为10
代码:
vgg16_false.classifier[6] = nn.Linear(4096, 10)
运行结果:

可以看到out_features=10,从而成功修改现有的网络模型。
相关文章:
优化器与现有网络模型的修改
文章目录 一、优化器是什么二、优化器的使用三、分类模型VGG16四、现有网络模型的修改 一、优化器是什么 优化器(Optimizer)是一个算法,用于在训练过程中调整模型的参数,以便最小化损失函数(Loss Function)…...
kafka 超详细的消息订阅与消息消费几种方式
kafka 消息订阅与消息消费几种方式 本文主要内容 消费者订阅几种方式 订阅多个主题 按正则表达式订阅 消息消费几种方式 按分区消费 按主题消费 不区分 “ 笔者建议一开始学习Kafka最好不要用SpringBoot 集成方式,因为SpringBoot推崇用注解方式,比如KafkaList…...
C++ 第三讲:内存管理
C 第三讲:内存管理 1.C内存分布2.内存管理方式2.1C语言内存管理方式2.2C内存管理方式2.2.1new\delete操作内置类型2.2.2new\delete操作自定义类型 3.operator new与operator delete函数4.new和delete实现原理4.1内置类型4.2自定义类型 5.定位new5.1内存池的基本了解…...
LeeCode打卡第二十九天
LeeCode打卡第二十九天 第一题:岛屿数量(LeeCode第200题): 给你一个由 1(陆地)和 0(水)组成的的二维网格,请你计算网格中岛屿的数量。岛屿总是被水包围,并且每座岛屿只…...
阿里云专业翻译api对接
最近我们一个商城项目涉及多语言切换,默认中文。用户切换语言可选英语和阿拉伯语言,前端APP和后端返回动态数据都要根据用户选择语言来展示。前端静态内容都做了三套语言,后端商品为了适用这种多语言我们也进行了改造。每一件商品名称&#x…...
基于Spring Boot的能源管理系统+建筑能耗+建筑能耗监测系统+节能监测系统+能耗监测+建筑能耗监测
介绍 建筑节能监测系统是基于计算机网络、物联网、大数据和数据可视化等多种技术融合形成的一套节能监测系统。 系统实现了对建筑电、水、热,气等能源、资源消耗情况的实时监测和预警、动态分析和评估,为用户建立了科学、系统的节能分析方法,…...
大数据新视界 --大数据大厂之 Cassandra 分布式数据库:高可用数据存储的新选择
💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…...
ROS第五梯:ROS+VSCode+C++单步调试
解决问题:在ROS项目中进行断点调试。 第一步:创建一个ROS项目或者打开一个现有的ROS项目。 第二步:修改c_cpp_properties.json 增加一段命令: "compileCommands": "${workspaceFolder}/build/compile_commands.json"第三…...
SLA 概念和计算方法
SLA 概念和计算方法 SLA SLA:服务等级协议(简称:SLA,全称:service level agreement) 网站服务可用性的一个保证 9越多代表全年服务可用时间越长服务更可靠,停机时间越短,反之亦然…...
C++比大小游戏
目录 开头程序程序的流程图程序游玩的效果下一篇博客要说的东西 开头 大家好,我叫这是我58。 程序 #include <iostream> #include <Windows.h> using namespace std; int main() {int ir 1;char chparr[2] { 0 };int ip1 0;int ip2 0;int i 1;c…...
PCIe进阶之TL:Memory, I/O, and Configuration Request Rules TPH Rules
1 Memory, I/O, and Configuration Request Rules 下述规则适用于 Memory 请求、IO 请求和配置请求。 除了公共的 header 字段外,所有 Memory 请求、IO 请求和配置请求还包括以下字段: (1)Requester ID[15:0] 和 Tag[9:0],组成了 Transaction ID 。 (2)Last DW BE[3:0]…...
【初阶数据结构】一文讲清楚 “堆” 和 “堆排序” -- 树和二叉树(二)(内含TOP-K问题)
文章目录 前言1. 堆1.1 堆的概念1.2 堆的分类 2. 堆的实现2.1 堆的结构体设置2.2 堆的初始化2.3 堆的销毁2.4 添加数据到堆2.4.1 "向上调整"算法 2.5 从堆中删除数据2.5.1 “向下调整”算法 2.6 堆的其它各种方法接口函数 3. 堆排序3.1 堆排序的代码实现 4. TOP-K问题…...
sqli-lab靶场学习(二)——Less8-10(盲注、时间盲注)
Less8 第八关依然是先看一般状态 http://localhost/sqli-labs/Less-8/?id1 然后用单引号闭合: http://localhost/sqli-labs/Less-8/?id1 这关的问题在于报错是不显示,那没办法通过上篇文章的updatexml大法处理。对于这种情况,需要用“盲…...
Dijkstra算法和BFS算法(单源最短路径)
基于你设计的带权有向图,从某一结点出发,执行Dijkstra算法求单源最短路径。用文字描述每一轮执行的过程 文字描述:用BFS算法求单源最短路径的过程 Dijkstra 算法 BFS算法 广度优先算法...
在WordPress中最佳Elementor主题推荐:专家级指南
对于已经在WordPress和Elementor上有丰富经验的用户来说,选择功能强大且高度灵活的主题,能大大提升网站的表现和定制能力。今天,我们来介绍六款适合用户的专家级Elementor主题:Sydney、Blocksy、Rife Free、Customify、Deep和Laye…...
关于RabbitMQ消息丢失的解决方案
RabbitMQ如何保证消息的可靠性传输 一、消息丢失的原因 1. 生产者端 网络问题: 原因:生产者与RabbitMQ服务器之间的网络连接不稳定或中断,导致消息在传输过程中丢失。解决方案:确保网络连接稳定,监控网络状态&#x…...
c语言动态内存分配
前言 我们已经掌握的内存开辟⽅式有: int val 20;//在栈空间上开辟四个字节 char arr[10] {0};//在栈空间上开辟10个字节的连续空间 但是上述的开辟空间的⽅式有两个特点: • 空间开辟⼤⼩是固定的。 • 数组在申明的时候,必须指定数组的…...
零基础制作一个ST-LINK V2 附PCB文件原理图 AD格式
资料下载地址:零基础制作一个ST-LINK V2 附PCB文件原理图 AD格式 ST-LINK/V2是一款可以在线仿真以及下载STM8以及STM32的开发工具。支持所有带SWIM接口的STM8系列单片机;支持所有带JTAG / SWD接口的STM32系列单片机。 基本属性 ST-LINK/V2是ST意法半导体为评估、开…...
nginx基础篇(一)
文章目录 学习链接概图一、Nginx简介1.1 背景介绍名词解释 1.2 常见服务器对比IISTomcatApacheLighttpd其他的服务器 1.3 Nginx的优点(1)速度更快、并发更高(2)配置简单,扩展性强(3)高可靠性(4)热部署(5)成本低、BSD许可证 1.4 Nginx的功能特性及常用功能基本HTTP服…...
监控系列之-Grafana面板展示及制作
一 Grafana设置添加数据源 1、设置Grafana中文显示 最后保存退出,数据源添加完毕 2、导入node_exporter主机监控面板 此处 有外网的情况下,直接输入对应面板的ID号,然后点击加载即可;无无外网的话,则考虑使用上传仪表…...
基于算法竞赛的c++编程(28)结构体的进阶应用
结构体的嵌套与复杂数据组织 在C中,结构体可以嵌套使用,形成更复杂的数据结构。例如,可以通过嵌套结构体描述多层级数据关系: struct Address {string city;string street;int zipCode; };struct Employee {string name;int id;…...
生成xcframework
打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式,可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...
【力扣数据库知识手册笔记】索引
索引 索引的优缺点 优点1. 通过创建唯一性索引,可以保证数据库表中每一行数据的唯一性。2. 可以加快数据的检索速度(创建索引的主要原因)。3. 可以加速表和表之间的连接,实现数据的参考完整性。4. 可以在查询过程中,…...
sqlserver 根据指定字符 解析拼接字符串
DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...
MySQL 8.0 OCP 英文题库解析(十三)
Oracle 为庆祝 MySQL 30 周年,截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始,将英文题库免费公布出来,并进行解析,帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...
智能仓储的未来:自动化、AI与数据分析如何重塑物流中心
当仓库学会“思考”,物流的终极形态正在诞生 想象这样的场景: 凌晨3点,某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径;AI视觉系统在0.1秒内扫描包裹信息;数字孪生平台正模拟次日峰值流量压力…...
css3笔记 (1) 自用
outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size:0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格ÿ…...
【JavaWeb】Docker项目部署
引言 之前学习了Linux操作系统的常见命令,在Linux上安装软件,以及如何在Linux上部署一个单体项目,大多数同学都会有相同的感受,那就是麻烦。 核心体现在三点: 命令太多了,记不住 软件安装包名字复杂&…...
OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 在 GPU 上对图像执行 均值漂移滤波(Mean Shift Filtering),用于图像分割或平滑处理。 该函数将输入图像中的…...
Mobile ALOHA全身模仿学习
一、题目 Mobile ALOHA:通过低成本全身远程操作学习双手移动操作 传统模仿学习(Imitation Learning)缺点:聚焦与桌面操作,缺乏通用任务所需的移动性和灵活性 本论文优点:(1)在ALOHA…...
