优化器与现有网络模型的修改
文章目录
- 一、优化器是什么
- 二、优化器的使用
- 三、分类模型VGG16
- 四、现有网络模型的修改
一、优化器是什么
优化器(Optimizer)是一个算法,用于在训练过程中调整模型的参数,以便最小化损失函数(Loss Function)。损失函数衡量的是模型预测值与真实值之间的差异,而优化器则负责通过更新模型的权重(Weights)和偏置(Biases)来减少这种差异。
利用得到的梯度,用优化器对梯度进行修正,从而得到整体误差降低的目的。
优化器Optimizer 所需要从参数:

参数解析:
- model.parameters()是训练的模型
- lr(LearningRate)是学习率,这是最核心的参数之一,它决定了在每次迭代中参数更新的步长。如果学习率太高,可能会导致训练过程中的梯度爆炸,使模型无法收敛,训练很不稳定;如果学习率太低,训练过程可能会变得非常缓慢。
推荐一开始用大的lr值进行运算,到后面用小的lr再进行运算。 - 动量(Momentum)往往是特定参数,是用于加速梯度下降方法,特别是在处理凸优化问题时。它通过在连续的迭代中累积梯度信息来帮助优化器克服局部最小值,并加快收敛速度。
二、优化器的使用
本文使用我的上一章内容神经网络内容进行续写,神经网络具体可跳转损失函数和反向传播
使用一下代码来进行梯度优化:
optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优optim.step()
整体代码如下:
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download= True)dataloader = DataLoader (dataset, batch_size = 1)
class Sen(nn.Module):def __init__(self):super(Sen,self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, 1, 2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = nn.CrossEntropyLoss()
sen = Sen()#随机梯度下降
optim = torch.optim.SGD(sen.parameters(), lr=0.01)for data in dataloader:imgs, tatgets = dataoutputs = sen(imgs)result_loss = loss(outputs, tatgets)#对参数进行梯度清零optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优optim.step()
在未运行时的梯度没有值:

当运行一下:

可以看到每个参数节点的值被计算出来了。
当for循环第二次运行的时候,可以看到grad梯度已经被优化了:

通过反复循环,上图中的data数据,也就是loss就会越来越被优化。
上面的for循环其实是为数据的一次小循环,我们可以加上epoch 外嵌套 进行数据的一轮轮循环深度优化:
for epoch in range(20):running_loss = 0.0#这里只是进行了一次的学习for data in dataloader:imgs, tatgets = dataoutputs = sen(imgs)result_loss = loss(outputs, tatgets)#对参数进行梯度清零optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优aoptim.step()#这一步就相当于所有误差的一个整体求和running_loss = running_loss + result_loss
整体代码:
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download= True)dataloader = DataLoader (dataset, batch_size = 1)
class Sen(nn.Module):def __init__(self):super(Sen,self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, 1, 2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = nn.CrossEntropyLoss()
sen = Sen()#随机梯度下降
optim = torch.optim.SGD(sen.parameters(), lr=0.01)#这里是进行一轮一轮的学习
for epoch in range(20):running_loss = 0.0#这里只是进行了一次的学习for data in dataloader:imgs, tatgets = dataoutputs = sen(imgs)result_loss = loss(outputs, tatgets)#对参数进行梯度清零optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优aoptim.step()#这一步就相当于所有误差的一个整体求和running_loss = running_loss + result_lossprint(running_loss)
运行结果如下,可以看到,整个神经网络在所有的数据当中,它的误差之和如下:

在第一轮优化的时候,整个神经网络的误差之和是18779
在第二轮优化的时候,整个神经网络的误差之和是16205
在第三轮优化的时候,整个神经网络的误差之和是15448
可以看到,通过优化器的一轮轮优化,整体的loss值会一直降低,从而达到数据优化的效果。
三、分类模型VGG16
pytorch为我们提供了很多网络模型,其中包括分类模型VGG16
分类模型VGG16是基于ImageNet数据集进行训练的,所以我们需要下载ImageNet数据集
由于ImageNet数据集的内存为143g,会发生以下报错,需要我们自己去下载ImageNet数据集再放在根目录当中。

既然ImageNet数据集太大,那么就换一条思路,用一下方法加载vgg16
import torchvision.datasets
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_True = torchvision.models.vgg16(pretrained=True)
print('ok')
如果pretrained = True,说明这个数据集已经是训练好的了。
如果pretrained = False,说明这些参数是一个初始参数,没有在任何参数集上面进行训练。
如果progress = True,显示下载进度条
如果progress = Flase,则不显示下载进度条
vgg16_false = torchvision.models.vgg16(pretrained=False),这代码表示只是加载网络模型(也就是像之前的网络模型那样,只是加载模型,含有卷积,池化等,其中的参数都是默认的),所以它不需要下载。
vgg16_True = torchvision.models.vgg16(pretrained=True),这代码表示需要把网络模型参数进行一个下载,还要加载对应的参数。故它需要进行下载。
简单理解就是False不需要进行下载,而True需要进行下载。
VGG16将数据集分成1000个类。
print(vgg16_true)
输出结果:


看它把各种卷积层,最大池化都自动按参数下载好了。
常用的CIFAR10会把数据集分成10个类。
vgg16会把数据集分成1000个类,如上图的out_features=1000
四、现有网络模型的修改
方法:像上面得到的是out_features=1000,我们可以进行一个新的处理,通过Linear将输入是1000,而输出为10,从而达到降类的效果。
vgg16_true.add_module("add_linear", nn.Linear(1000, 10))
运行得到:

可以看到,在add_linear这里的out_features=10
如果要想类的改变在classifier当中,那么代码只需要添加上classifier
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
运行结果:

整体代码如下:
import torchvision.datasets
from torch import nnvgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)print(vgg16_true)train_data = torchvision.datasets.CIFAR10("./data",train=True, transform=torchvision.transforms.ToTensor(),download=True)vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
如果想直接在上面 (6)Linear 里面修改out_features,而不是新命名一个(add_linear)进行修改也是可以的
用vgg16_flase进行示范:
在没进行修改前print(vgg16_false)
运行结果:

直接在(6)Linear中修改out_features为10
代码:
vgg16_false.classifier[6] = nn.Linear(4096, 10)
运行结果:

可以看到out_features=10,从而成功修改现有的网络模型。
相关文章:
优化器与现有网络模型的修改
文章目录 一、优化器是什么二、优化器的使用三、分类模型VGG16四、现有网络模型的修改 一、优化器是什么 优化器(Optimizer)是一个算法,用于在训练过程中调整模型的参数,以便最小化损失函数(Loss Function)…...
kafka 超详细的消息订阅与消息消费几种方式
kafka 消息订阅与消息消费几种方式 本文主要内容 消费者订阅几种方式 订阅多个主题 按正则表达式订阅 消息消费几种方式 按分区消费 按主题消费 不区分 “ 笔者建议一开始学习Kafka最好不要用SpringBoot 集成方式,因为SpringBoot推崇用注解方式,比如KafkaList…...
C++ 第三讲:内存管理
C 第三讲:内存管理 1.C内存分布2.内存管理方式2.1C语言内存管理方式2.2C内存管理方式2.2.1new\delete操作内置类型2.2.2new\delete操作自定义类型 3.operator new与operator delete函数4.new和delete实现原理4.1内置类型4.2自定义类型 5.定位new5.1内存池的基本了解…...
LeeCode打卡第二十九天
LeeCode打卡第二十九天 第一题:岛屿数量(LeeCode第200题): 给你一个由 1(陆地)和 0(水)组成的的二维网格,请你计算网格中岛屿的数量。岛屿总是被水包围,并且每座岛屿只…...
阿里云专业翻译api对接
最近我们一个商城项目涉及多语言切换,默认中文。用户切换语言可选英语和阿拉伯语言,前端APP和后端返回动态数据都要根据用户选择语言来展示。前端静态内容都做了三套语言,后端商品为了适用这种多语言我们也进行了改造。每一件商品名称&#x…...
基于Spring Boot的能源管理系统+建筑能耗+建筑能耗监测系统+节能监测系统+能耗监测+建筑能耗监测
介绍 建筑节能监测系统是基于计算机网络、物联网、大数据和数据可视化等多种技术融合形成的一套节能监测系统。 系统实现了对建筑电、水、热,气等能源、资源消耗情况的实时监测和预警、动态分析和评估,为用户建立了科学、系统的节能分析方法,…...
大数据新视界 --大数据大厂之 Cassandra 分布式数据库:高可用数据存储的新选择
💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…...
ROS第五梯:ROS+VSCode+C++单步调试
解决问题:在ROS项目中进行断点调试。 第一步:创建一个ROS项目或者打开一个现有的ROS项目。 第二步:修改c_cpp_properties.json 增加一段命令: "compileCommands": "${workspaceFolder}/build/compile_commands.json"第三…...
SLA 概念和计算方法
SLA 概念和计算方法 SLA SLA:服务等级协议(简称:SLA,全称:service level agreement) 网站服务可用性的一个保证 9越多代表全年服务可用时间越长服务更可靠,停机时间越短,反之亦然…...
C++比大小游戏
目录 开头程序程序的流程图程序游玩的效果下一篇博客要说的东西 开头 大家好,我叫这是我58。 程序 #include <iostream> #include <Windows.h> using namespace std; int main() {int ir 1;char chparr[2] { 0 };int ip1 0;int ip2 0;int i 1;c…...
PCIe进阶之TL:Memory, I/O, and Configuration Request Rules TPH Rules
1 Memory, I/O, and Configuration Request Rules 下述规则适用于 Memory 请求、IO 请求和配置请求。 除了公共的 header 字段外,所有 Memory 请求、IO 请求和配置请求还包括以下字段: (1)Requester ID[15:0] 和 Tag[9:0],组成了 Transaction ID 。 (2)Last DW BE[3:0]…...
【初阶数据结构】一文讲清楚 “堆” 和 “堆排序” -- 树和二叉树(二)(内含TOP-K问题)
文章目录 前言1. 堆1.1 堆的概念1.2 堆的分类 2. 堆的实现2.1 堆的结构体设置2.2 堆的初始化2.3 堆的销毁2.4 添加数据到堆2.4.1 "向上调整"算法 2.5 从堆中删除数据2.5.1 “向下调整”算法 2.6 堆的其它各种方法接口函数 3. 堆排序3.1 堆排序的代码实现 4. TOP-K问题…...
sqli-lab靶场学习(二)——Less8-10(盲注、时间盲注)
Less8 第八关依然是先看一般状态 http://localhost/sqli-labs/Less-8/?id1 然后用单引号闭合: http://localhost/sqli-labs/Less-8/?id1 这关的问题在于报错是不显示,那没办法通过上篇文章的updatexml大法处理。对于这种情况,需要用“盲…...
Dijkstra算法和BFS算法(单源最短路径)
基于你设计的带权有向图,从某一结点出发,执行Dijkstra算法求单源最短路径。用文字描述每一轮执行的过程 文字描述:用BFS算法求单源最短路径的过程 Dijkstra 算法 BFS算法 广度优先算法...
在WordPress中最佳Elementor主题推荐:专家级指南
对于已经在WordPress和Elementor上有丰富经验的用户来说,选择功能强大且高度灵活的主题,能大大提升网站的表现和定制能力。今天,我们来介绍六款适合用户的专家级Elementor主题:Sydney、Blocksy、Rife Free、Customify、Deep和Laye…...
关于RabbitMQ消息丢失的解决方案
RabbitMQ如何保证消息的可靠性传输 一、消息丢失的原因 1. 生产者端 网络问题: 原因:生产者与RabbitMQ服务器之间的网络连接不稳定或中断,导致消息在传输过程中丢失。解决方案:确保网络连接稳定,监控网络状态&#x…...
c语言动态内存分配
前言 我们已经掌握的内存开辟⽅式有: int val 20;//在栈空间上开辟四个字节 char arr[10] {0};//在栈空间上开辟10个字节的连续空间 但是上述的开辟空间的⽅式有两个特点: • 空间开辟⼤⼩是固定的。 • 数组在申明的时候,必须指定数组的…...
零基础制作一个ST-LINK V2 附PCB文件原理图 AD格式
资料下载地址:零基础制作一个ST-LINK V2 附PCB文件原理图 AD格式 ST-LINK/V2是一款可以在线仿真以及下载STM8以及STM32的开发工具。支持所有带SWIM接口的STM8系列单片机;支持所有带JTAG / SWD接口的STM32系列单片机。 基本属性 ST-LINK/V2是ST意法半导体为评估、开…...
nginx基础篇(一)
文章目录 学习链接概图一、Nginx简介1.1 背景介绍名词解释 1.2 常见服务器对比IISTomcatApacheLighttpd其他的服务器 1.3 Nginx的优点(1)速度更快、并发更高(2)配置简单,扩展性强(3)高可靠性(4)热部署(5)成本低、BSD许可证 1.4 Nginx的功能特性及常用功能基本HTTP服…...
监控系列之-Grafana面板展示及制作
一 Grafana设置添加数据源 1、设置Grafana中文显示 最后保存退出,数据源添加完毕 2、导入node_exporter主机监控面板 此处 有外网的情况下,直接输入对应面板的ID号,然后点击加载即可;无无外网的话,则考虑使用上传仪表…...
变量 varablie 声明- Rust 变量 let mut 声明与 C/C++ 变量声明对比分析
一、变量声明设计:let 与 mut 的哲学解析 Rust 采用 let 声明变量并通过 mut 显式标记可变性,这种设计体现了语言的核心哲学。以下是深度解析: 1.1 设计理念剖析 安全优先原则:默认不可变强制开发者明确声明意图 let x 5; …...
生成xcframework
打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式,可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...
进程地址空间(比特课总结)
一、进程地址空间 1. 环境变量 1 )⽤户级环境变量与系统级环境变量 全局属性:环境变量具有全局属性,会被⼦进程继承。例如当bash启动⼦进程时,环 境变量会⾃动传递给⼦进程。 本地变量限制:本地变量只在当前进程(ba…...
ffmpeg(四):滤镜命令
FFmpeg 的滤镜命令是用于音视频处理中的强大工具,可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下: ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜: ffmpeg…...
iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈
在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...
STM32HAL库USART源代码解析及应用
STM32HAL库USART源代码解析 前言STM32CubeIDE配置串口USART和UART的选择使用模式参数设置GPIO配置DMA配置中断配置硬件流控制使能生成代码解析和使用方法串口初始化__UART_HandleTypeDef结构体浅析HAL库代码实际使用方法使用轮询方式发送使用轮询方式接收使用中断方式发送使用中…...
MacOS下Homebrew国内镜像加速指南(2025最新国内镜像加速)
macos brew国内镜像加速方法 brew install 加速formula.jws.json下载慢加速 🍺 最新版brew安装慢到怀疑人生?别怕,教你轻松起飞! 最近Homebrew更新至最新版,每次执行 brew 命令时都会自动从官方地址 https://formulae.…...
Leetcode33( 搜索旋转排序数组)
题目表述 整数数组 nums 按升序排列,数组中的值 互不相同 。 在传递给函数之前,nums 在预先未知的某个下标 k(0 < k < nums.length)上进行了 旋转,使数组变为 [nums[k], nums[k1], …, nums[n-1], nums[0], nu…...
沙箱虚拟化技术虚拟机容器之间的关系详解
问题 沙箱、虚拟化、容器三者分开一一介绍的话我知道他们各自都是什么东西,但是如果把三者放在一起,它们之间到底什么关系?又有什么联系呢?我不是很明白!!! 就比如说: 沙箱&#…...
如何配置一个sql server使得其它用户可以通过excel odbc获取数据
要让其他用户通过 Excel 使用 ODBC 连接到 SQL Server 获取数据,你需要完成以下配置步骤: ✅ 一、在 SQL Server 端配置(服务器设置) 1. 启用 TCP/IP 协议 打开 “SQL Server 配置管理器”。导航到:SQL Server 网络配…...
