Pytorch微调深度学习模型
在公开数据训练了模型,有时候需要拿到自己的数据上微调。今天正好做了一下微调,在此记录一下微调的方法。用Pytorch还是比较容易实现的。
网上找了很多方法,以及Chatgpt也给了很多方法,但是不够简洁和容易理解。
大体步骤是:
1、加载训练好的模型。
2、冻结不想微调的层,设置想训练的层。(这里可以新建一个层替换原有层,也可以不新建层,直接微调原有层)
3、训练即可。
1、先加载一个模型
我这里是训练好的一个SqueezeNet模型,所有模型都适用。
## 加载要微调的模型
# 环境里必须有模型的框架,才能torch.load
from Model.main_SqueezeNet import SqueezeNet,Firemodel = torch.load("Model/SqueezeNet.pth").to(device)
print(model)
# 输出结果
SqueezeNet((stem): Sequential((0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(fire2): Fire((squeeze): Sequential((0): Conv2d(8, 4, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(expand_1x1): Sequential((0): Conv2d(4, 8, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(expand_3x3): Sequential((0): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)))(fire3): Fire((squeeze): Sequential((0): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(expand_1x1): Sequential((0): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(expand_3x3): Sequential((0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)))(fire4): Fire((squeeze): Sequential((0): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(expand_1x1): Sequential((0): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(expand_3x3): Sequential((0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)))(conv10): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))(avg): AdaptiveAvgPool2d(output_size=1)(maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
print(model)时会显示模型每个层的名字。这里我想对conv10层进行微调,因为它是最后一个具有参数可以微调的层了。当然,如果最后一层是全连接的话,也建议微调最后全连接层。
2、冻结不想训练的层。
这里就有两种不同的方法了:一是新建一个conv10层,替换掉原来的层。二是不新建,直接微调原来的层。
新建:
model.conv10 = nn.Conv2d(model.conv10.in_channels, model.conv10.out_channels, model.conv10.kernel_size, model.conv10.stride)
print(model)
可以直接用model.conv10.in_channels等加载原来层的各种参数。这样就定义好了一个新的conv10层,并且已经替换进了模型中。
然后先冻结所有层(requires_grad = False),再放开conv10层(requires_grad = True)。
# 先冻结所有层
for param in model.parameters():param.requires_grad = False# 仅对conv10层进行微调,如果在冻结后新定义了conv10层,这两行可以不写,默认有梯度
for param in model.conv10.parameters():param.requires_grad = True
如果不新建层,则不需要运行model.conv10 = nn.Conv2d那一行即可。直接开始冻结就可以。
3、训练
这里一定要注意,optimizer里要设置参数 model.conv10.parameters(),而不是model.parameters()。这是让模型知道它将要训练哪些参数。
optimizer = optim.SGD(model.conv10.parameters(), lr=1e-2)
虽然上面已经冻结了不想训练的参数,但是这里最好还是写上model.conv10.parameters()。大家也可以试试不写行不行。
# 使用交叉熵损失函数
criterion = nn.CrossEntropyLoss()
# 只优化conv10层的参数
optimizer = optim.SGD(model.conv10.parameters(), lr=1e-2)
# 将模型移到GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 设置模型为训练模式
model.train()num_epochs = 10
for epoch in range(num_epochs):# model.train()running_loss = 0.0correct = 0for x_train, y_train in data_loader:x_train, y_train = x_train.to(device), y_train.to(device)print(x_train.shape, y_train.shape)# 前向传播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item() * x_train.size(0)# 统计训练集的准确率_, predicted = torch.max(outputs, 1)correct += (predicted == y_train).sum().item()# 计算每个 epoch 的训练损失和准确率epoch_loss = running_loss / len(dataset)epoch_accuracy = 100 * correct / len(dataset)# if epoch % 5 == 0 or epoch == num_epochs-1 :print(f'Epoch [{epoch+1}/{num_epochs}]')print(f'Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_accuracy:.2f}%')
输出显示Loss下降说明模型有在学习。 模型准确率从0变成100,还是非常有成就感的!当然我这里就用了一个样本来微调hhhh。
Epoch [1/10]
Train Loss: 0.8185, Train Accuracy: 0.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [2/10]
Train Loss: 0.7063, Train Accuracy: 0.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [3/10]
Train Loss: 0.6141, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [4/10]
Train Loss: 0.5385, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [5/10]
Train Loss: 0.4761, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [6/10]
Train Loss: 0.4244, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [7/10]
Train Loss: 0.3812, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [8/10]
Train Loss: 0.3449, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [9/10]
Train Loss: 0.3140, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [10/10]
Train Loss: 0.2876, Train Accuracy: 100.00%
4、验证一下确实是只有这个层参数变化了,而其他层参数没变。
在训练模型之前,看一下这个层的参数:
raw_parm = model.conv10.weight
print(raw_parm)
# 部分输出为
Parameter containing:
tensor([[[[-0.1621]],[[ 0.0288]],[[ 0.1275]],[[ 0.1584]],[[ 0.0248]],[[-0.2013]],[[-0.2086]],[[ 0.1460]],[[ 0.0566]],[[ 0.2897]],[[ 0.2898]],[[ 0.0610]],[[ 0.2172]],[[ 0.0860]],[[ 0.2730]],[[-0.1053]]],
训练后,也输出一下这个层的参数:
## 查看微调后模型的参数
tuned_parm = model.conv10.weight
print(tuned_parm)
# 部分输出为:
Parameter containing:
tensor([[[[-0.1446]],[[ 0.0365]],[[ 0.1490]],[[ 0.1783]],[[ 0.0424]],[[-0.1826]],[[-0.1903]],[[ 0.1636]],[[ 0.0755]],[[ 0.3092]],[[ 0.3093]],[[ 0.0833]],[[ 0.2405]],[[ 0.1049]],[[ 0.2925]],[[-0.0866]]],
可见这个层的参数确实是变了。
然后检查一下别的随便一个层:
训练前:
# 训练前
raw_parm = model.stem[0].weight
print(raw_parm)
# 部分输出为:
Parameter containing:
tensor([[[[-0.0723, -0.2151, 0.1123],[-0.2114, 0.0173, -0.1322],[-0.0819, 0.0748, -0.2790]]],[[[-0.0918, -0.2783, -0.3193],[ 0.0359, 0.2993, -0.3422],[ 0.1979, 0.2499, -0.0528]]],
训练后:
## 查看微调后模型的参数
tuned_parm = model.stem[0].weight
print(tuned_parm)
# 部分输出为:
Parameter containing:
tensor([[[[-0.0723, -0.2151, 0.1123],[-0.2114, 0.0173, -0.1322],[-0.0819, 0.0748, -0.2790]]],[[[-0.0918, -0.2783, -0.3193],[ 0.0359, 0.2993, -0.3422],[ 0.1979, 0.2499, -0.0528]]],
可见参数没有变化。说明这层没有进行学习。
5、为了让大家更容易全面理解,完整代码如下。
import torch
import numpy as np
import torch.optim as optim
import torch.nn as nn
from torchinfo import summary
from torch.utils.data import DataLoader, Dataset,TensorDataset
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
from imblearn.under_sampling import RandomUnderSampler # 多数样本下采样device = torch.device("cuda" if torch.cuda.is_available() else "cpu")## 加载微调数据
feats = np.load("feats_jn105.npy")
labels = np.array([0])
print(feats.shape)
print(labels.shape)# 将data和labels转换为 PyTorch 张量
data_tensor = torch.tensor(feats, dtype = torch.float32, requires_grad=True)
labels_tensor = torch.tensor(labels, dtype = torch.long)# 添加通道维度
# data_tensor = data_tensor.unsqueeze(1) # 变为(num, 1, 32, 16)
batch_size = 15# 创建 TensorDataset
dataset = TensorDataset(data_tensor, labels_tensor)
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = False)
input, label = next(iter(data_loader))
print(input.shape,label.shape)
# upyter nbconvert --to script ./Model/main_SqueezeNet.ipynb # 终端运行,ipynb转py## 加载要微调的模型
# 环境里必须有模型的框架,才能torch.load
from Model.main_SqueezeNet import SqueezeNet,Firemodel = torch.load("Model/SqueezeNet.pth").to(device)
print(model)# 为模型写一个新的层
# model.fc = nn.Linear(in_features = model.fc.in_features, out_features = model.fc.out_features)
model.conv10 = nn.Conv2d(model.conv10.in_channels, model.conv10.out_channels, model.conv10.kernel_size, model.conv10.stride)
print(model)# 先冻结所有层
for param in model.parameters():param.requires_grad = False# 仅对conv10层进行微调,如果在冻结后新定义了conv10层,这两行可以不写,默认有梯度
for param in model.conv10.parameters():param.requires_grad = Trueraw_parm = model.stem[0].weight
print(raw_parm)
for name, param in model.named_parameters():print(name, param.requires_grad)# 使用交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 只优化c10层的参数
optimizer = optim.SGD(model.conv10.parameters(), lr=1e-2)# 将模型移到GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 设置模型为训练模式
model.train()num_epochs = 10
for epoch in range(num_epochs):# model.train()running_loss = 0.0correct = 0for x_train, y_train in data_loader:x_train, y_train = x_train.to(device), y_train.to(device)print(x_train.shape, y_train.shape)# 前向传播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item() * x_train.size(0)# 统计训练集的准确率_, predicted = torch.max(outputs, 1)correct += (predicted == y_train).sum().item()# 计算每个 epoch 的训练损失和准确率epoch_loss = running_loss / len(dataset)epoch_accuracy = 100 * correct / len(dataset)# if epoch % 5 == 0 or epoch == num_epochs-1 :print(f'Epoch [{epoch+1}/{num_epochs}]')print(f'Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_accuracy:.2f}%')## 查看微调后模型的参数
tuned_parm = model.stem[0].weight
print(tuned_parm)
如有更好的方法,欢迎大家分享~
相关文章:
Pytorch微调深度学习模型
在公开数据训练了模型,有时候需要拿到自己的数据上微调。今天正好做了一下微调,在此记录一下微调的方法。用Pytorch还是比较容易实现的。 网上找了很多方法,以及Chatgpt也给了很多方法,但是不够简洁和容易理解。 大体步骤是&…...
springboot 使用笔记
1.springboot 快速启动项目 注意:该启动只是临时启动,不能关闭终端面板 cd /www/wwwroot java -jar admin.jar2.脚本启动 linux shell脚本启动springboot服务 3.java一键部署springboot 第5条 https://blog.csdn.net/qq_30272167/article/details/1…...
网络安全基础——网络安全法
填空题 1.根据**《中华人民共和国网络安全法》**第二十条(第二款),任何组织和个人试用网路应当遵守宪法法律,遵守公共秩序,遵守社会公德,不危害网络安全,不得利用网络从事危害国家安全、荣誉和利益,煽动颠…...
SCAU软件体系结构实验四 组合模式
目录 一、题目 二、源码 一、题目 个人(Person)与团队(Team)可以形成一个组织(Organization):组织有两种:个人组织和团队组织,多个个人可以组合成一个团队,不同的个人与团队可以组合成一个更大的团队。 使用控制台或者JavaFx界面…...
Amazon商品详情API接口:电商创新与用户体验的驱动力
在电子商务蓬勃发展的今天,作为全球最大的电商平台之一,亚马逊(Amazon)凭借其强大的技术实力和丰富的商品资源,为全球用户提供了优质的购物体验。其中,Amazon商品详情API接口在电商创新与用户体验提升方面扮…...
手机无法连接服务器1302什么意思?
你有没有遇到过手机无法连接服务器,屏幕上显示“1302”这样的错误代码?尤其是在急需使用手机进行工作或联系朋友时,突然出现的连接问题无疑会带来不少麻烦。那么,什么是1302错误,它又意味着什么呢? 1302错…...
Android adb shell dumpsys audio 信息查看分析详解
Android adb shell dumpsys audio 信息查看分析详解 一、前言 Android 如果要分析当前设备的声音通道相关日志, 仅仅看AudioService的日志是看不到啥日志的,但是看整个audio关键字的日志又太多太乱了, 所以可以看一下系统提供的一个调试指令…...
Python 网络爬虫操作指南
网络爬虫是自动化获取互联网上信息的一种工具。它广泛应用于数据采集、分析以及实现信息聚合等众多领域。本文将为你提供一个完整的Python网络爬虫操作指南,帮助你从零开始学习并实现简单的网络爬虫。我们将涵盖基本的爬虫概念、Python环境配置、常用库介绍。 上传…...
基于FPGA的2FSK调制-串口收发-带tb仿真文件-实际上板验证成功
基于FPGA的2FSK调制 前言一、2FSK储备知识二、代码分析1.模块分析2.波形分析 总结 前言 设计实现连续相位 2FSK 调制器,2FSK 的两个频率为:fI15KHz,f23KHz,波特率为 1500 bps,比特0映射为f 载波,比特1映射为 载波。 1)…...
JavaScript的基础数据类型
一、JavaScript中的数组 定义 数组是一种特殊的对象,用于存储多个值。在JavaScript中,数组可以包含不同的数据类型,如数字、字符串、对象、甚至其他数组。数组的创建有两种常见方式: 字面量表示法:let fruits [apple…...
第三讲 架构详解:“隐语”可信隐私计算开源框架
目录 隐语架构 隐语架构拆解 产品层 算法层 计算层 资源层 互联互通 跨域管控 本文主要是记录参加隐语开源社区推出的第四期隐私计算实训营学习到的相关内容。 隐语架构 隐语架构拆解 产品层 产品定位: 通过可视化产品,降低终端用户的体验和演…...
JDBC编程---Java
目录 一、数据库编程的前置 二、Java的数据库编程----JDBC 1.概念 2.JDBC编程的优点 三.导入MySQL驱动包 四、JDBC编程的实战 1.创造数据源,并设置数据库所在的位置,三条固定写法 2.建立和数据库服务器之间的连接,连接好了后ÿ…...
Python绘制太极八卦
文章目录 系列目录写在前面技术需求1. 图形绘制库的支持2. 图形绘制功能3. 参数化设计4. 绘制控制5. 数据处理6. 用户界面 完整代码代码分析1. rset() 函数2. offset() 函数3. taiji() 函数4. bagua() 函数5. 绘制过程6. 技术亮点 写在后面 系列目录 序号直达链接爱心系列1Pyth…...
Spring框架特性及包下载(Java EE 学习笔记04)
1 Spring 5的新特性 Spring 5是Spring当前最新的版本,与历史版本对比,Spring 5对Spring核心框架进行了修订和更新,增加了很多新特性,如支持响应式编程等。 更新JDK基线 因为Spring 5代码库运行于JDK 8之上,所以Spri…...
Linux关于vim的笔记
Linux关于vim的笔记:(vimtutor打开vim 教程) --------------------------------------------------------------------------------------------------------------------------------- 1. 光标在屏幕文本中的移动既可以用箭头键,也可以使用 hjkl 字母键…...
linux mount nfs开机自动挂载远程目录
要在Linux系统中实现开机自动挂载NFS共享目录,你需要编辑/etc/fstab文件。以下是具体步骤和示例: 确保你的系统已经安装了NFS客户端。如果没有安装,可以使用以下命令安装: sudo apt-install nfs-common 编辑/etc/fstab文件&#…...
【vue】导航守卫
什么是导航守卫 在vue路由切换过程中对行为做个限制 全局前置守卫 route.beforeEach((to, from, next)) > {// to是切换到的路由// from是正要离开的路由// next控制是否允许进入目标路由next(false); //不允许 }路由级别的导航守卫 const routes [{path: /User,name: U…...
基于Matlab实现LDPC编码
在无线通信和数据存储领域,LDPC(低密度奇偶校验码)编码是一种高效、纠错能力强大的错误校正技术。本MATLAB仿真程序全面地展示了如何在AWGN(加性高斯白噪声)信道下应用LDPC编码与BPSK(二进制相移键控&#…...
PostgreSQL 中约束Constraints
在 PostgreSQL 中,约束(Constraints)是用于限制进入数据库表中数据的规则。它们确保数据的准确性和可靠性,通过定义规则来防止无效数据的插入或更新。PostgreSQL 支持多种类型的约束,每种约束都有特定的用途和语法。以…...
✨系统设计时应时刻考虑设计模式基础原则
目录 💫单一职责原则 (Single Responsibility Principle, SRP)💫开放-封闭原则 (Open-Closed Principle, OCP)💫依赖倒转原则 (Dependency Inversion Principle, DIP)💫里氏代换原则 (Liskov Substitution Principle, LSP)&#x…...
音乐解密技术探秘:从加密困境到跨平台解决方案
音乐解密技术探秘:从加密困境到跨平台解决方案 【免费下载链接】unlock-music 在浏览器中解锁加密的音乐文件。原仓库: 1. https://github.com/unlock-music/unlock-music ;2. https://git.unlock-music.dev/um/web 项目地址: https://gitc…...
【FreeRTOS实战入门】一、从CubeMX到第一个任务:手把手搭建FreeRTOS工程
1. 为什么选择FreeRTOS与CubeMX组合 第一次接触嵌入式实时操作系统时,很多人会纠结选择哪种RTOS。我当年在uC/OS-II和FreeRTOS之间犹豫了很久,最终选择了后者。原因很简单:FreeRTOS不仅完全免费开源,还有STM32CubeMX这个神器加持。…...
Django 学习日记(补充1)| 彻底吃透:自定义 JWT 认证 + 全局登录中间件
大家好,这是我 Django 学习日记的第三篇。上一篇我们把路由、反向解析、DRF 自动路由、媒体文件、跨域全部讲明白了。今天我们进入整个项目最核心、最安全、最关键的部分:用户登录认证体系(在进入视图前的一篇补充文章)。本文将从…...
魔兽世界API开发助手:从新手到专家的全流程解决方案
魔兽世界API开发助手:从新手到专家的全流程解决方案 【免费下载链接】wow_api Documents of wow API -- 魔兽世界API资料以及宏工具 项目地址: https://gitcode.com/gh_mirrors/wo/wow_api 价值定位:如何避免90%的插件开发陷阱? 在魔…...
终极Ghidra安装指南:5分钟在Ubuntu系统快速部署逆向工程神器
终极Ghidra安装指南:5分钟在Ubuntu系统快速部署逆向工程神器 【免费下载链接】ghidra_installer Helper scripts to set up OpenJDK 11 and scale Ghidra for 4K on Ubuntu 18.04 / 18.10 项目地址: https://gitcode.com/gh_mirrors/gh/ghidra_installer 想要…...
COMSOL 探索岩石力学多场景:损伤、压裂、试验与模拟
COMSOL岩石损伤、水力压裂、三轴试验 岩石在膨胀剂的膨胀作用下的损伤; 相场法与水力压裂(6个模型); 不固结不排水三轴试验; 二维钻孔封孔效果模拟。在岩石力学领域,COMSOL 如同一个强大的实验室,让我们能够对复杂的岩…...
如何快速为Obsidian插件添加状态栏功能:完整指南与实用示例
如何快速为Obsidian插件添加状态栏功能:完整指南与实用示例 【免费下载链接】obsidian-sample-plugin 项目地址: https://gitcode.com/GitHub_Trending/ob/obsidian-sample-plugin Obsidian Sample Plugin是一个官方提供的插件开发示例,展示了如…...
知识图谱项目实战(基础概念以及工具使用)【第一章】
在RAG以及Agent的应用领域中,知识图谱可以增强知识库的检索效果(通过搭建知识图谱数据库(GraphRag)实现).在教育医疗以及金融领域应用广泛.图谱(graph)有节点和边组成一.知识图谱理论1.1知识图谱的整体架构1.2知识图谱架构实现流程1. 文本标注(Doccano标…...
Fillinger智能填充脚本终极指南:如何快速实现图形元素的智能分布
Fillinger智能填充脚本终极指南:如何快速实现图形元素的智能分布 【免费下载链接】illustrator-scripts Adobe Illustrator scripts 项目地址: https://gitcode.com/gh_mirrors/il/illustrator-scripts Fillinger是一款专为Adobe Illustrator设计的智能填充脚…...
如何在群晖NAS上部署百度网盘客户端:终极安装与配置指南
如何在群晖NAS上部署百度网盘客户端:终极安装与配置指南 【免费下载链接】synology-baiduNetdisk-package 项目地址: https://gitcode.com/gh_mirrors/sy/synology-baiduNetdisk-package 还在为群晖NAS与百度网盘之间的文件同步问题而烦恼吗?群晖…...
