day46 python预训练模型补充
目录
一、预训练模型的背景知识
二、实验过程
(一)实验环境与数据准备
(二)预训练模型的选择与适配
(三)训练策略
三、实验结果与分析
四、学习总结与展望
一、预训练模型的背景知识
在传统的神经网络训练中,模型的参数是随机初始化的,这可能导致训练初期的不稳定,并且容易陷入局部最优解。而预训练模型的出现,为这一问题提供了有效的解决方案。预训练模型是在大规模数据集(如 ImageNet)上预先训练好的模型,它已经学习到了丰富的通用特征。当我们面临一个新的图像分类任务时,可以直接利用这些预训练好的模型参数来初始化我们的模型,这样模型在初始阶段就具备了一定的特征提取能力,能够更快地收敛,并且在一定程度上避免了局部最优解的问题。
预训练模型的选择至关重要。首先,预训练任务与目标任务的相似性是关键因素。如果两个任务在特征层面具有相似性,那么预训练模型提取的特征将对目标任务更有帮助。其次,预训练数据集的规模也非常重要。大规模的数据集能够支撑模型学习到更通用的特征,从而在不同的任务中具有更好的泛化能力。例如,ImageNet 数据集拥有 1000 个类别,1.2 亿张图像,尺寸为 224x224,是一个非常适合用于预训练的大规模图像数据集。
二、实验过程
(一)实验环境与数据准备
本次实验使用的是 PyTorch 深度学习框架,借助其丰富的预训练模型库和便捷的数据处理工具,能够高效地完成模型的加载、训练和测试。实验中使用的 CIFAR-10 数据集是一个经典的图像分类数据集,包含 10 个类别,共 60000 张 32x32 的彩色图像,其中训练集有 50000 张图像,测试集有 10000 张图像。由于 CIFAR-10 图像的尺寸较小,且类别相对较少,因此直接在该数据集上训练模型可能会面临过拟合等问题,而预训练模型的引入则有望缓解这一问题。
在数据预处理阶段,为了增强模型的泛化能力,对训练集进行了多种数据增强操作,包括随机裁剪、随机水平翻转、颜色抖动以及随机旋转等。这些操作能够在训练过程中为模型提供更多的“干扰”或变形,使模型能够学习到更加鲁棒的特征。具体的数据预处理代码如下:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题# 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 数据预处理(训练集增强,测试集标准化)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=test_transform
)# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
对于测试集,则仅进行了标准化处理,以确保其与训练集在数据分布上具有一致性。
(二)预训练模型的选择与适配
在本次实验中,选择了 ResNet18 作为预训练模型。ResNet18 是一种经典的卷积神经网络架构,具有 18 层深度,并且通过残差连接解决了深层网络训练中的梯度消失问题。它在 ImageNet 数据集上预训练后,能够提取出具有较强表达能力的特征。
由于 ResNet18 在 ImageNet 数据集上预训练时,其输入图像尺寸为 224x224,而 CIFAR-10 图像的尺寸为 32x32,因此需要对模型进行适当的调整以适配 CIFAR-10 数据集。具体来说,需要修改模型的最后一层全连接层,将其输出类别数从 1000 改为 10,以匹配 CIFAR-10 的类别数量。此外,由于输入图像尺寸的变化,还需要调整模型中的一些层的参数,例如池化层的步长等。以下是 ResNet18 模型的适配代码:
from torchvision.models import resnet18# 定义 ResNet18 模型(支持预训练权重加载)
def create_resnet18(pretrained=True, num_classes=10):model = resnet18(pretrained=pretrained)in_features = model.fc.in_featuresmodel.fc = nn.Linear(in_features, num_classes)return model.to(device)# 创建 ResNet18 模型(加载 ImageNet 预训练权重,不进行微调)
model = create_resnet18(pretrained=True, num_classes=10)
model.eval() # 设置为推理模式
(三)训练策略
在训练过程中,采用了阶段式训练策略。首先,冻结模型的卷积层参数,仅训练全连接层,这样可以在不破坏预训练模型特征提取能力的前提下,快速调整模型的输出层以适应 CIFAR-10 数据集。经过一定轮次的训练后,解冻模型的所有参数,进行整体训练,以进一步提升模型的性能。这种策略能够在训练初期快速降低损失,并在后续训练中充分利用预训练模型的特征提取能力,实现更好的收敛效果。
具体来说,实验中设置了前 5 轮冻结卷积层参数,之后解冻所有参数进行训练。在解冻后,为了防止过拟合,还适当降低了学习率。以下是训练函数的关键代码:
# 冻结/解冻模型层的函数
def freeze_model(model, freeze=True):"""冻结或解冻模型的卷积层参数"""for name, param in model.named_parameters():if 'fc' not in name:param.requires_grad = not freezefrozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)total_params = sum(p.numel() for p in model.parameters())if freeze:print(f"已冻结模型卷积层参数 ({frozen_params}/{total_params} 参数)")else:print(f"已解冻模型所有参数 ({total_params}/{total_params} 参数可训练)")return model# 训练函数(支持阶段式训练)
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):model = freeze_model(model, freeze=True)for epoch in range(epochs):if epoch == freeze_epochs:model = freeze_model(model, freeze=False)optimizer.param_groups[0]['lr'] = 1e-4 # 解冻后调整学习率model.train()running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_trainmodel.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_testif scheduler is not None:scheduler.step(epoch_test_loss)print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")
三、实验结果与分析
经过 40 轮的训练,最终测试准确率达到了 86.30%。从训练过程的输出中可以看出,在解冻卷积层参数后,模型的训练损失迅速下降,训练准确率和测试准确率都得到了显著提升。这充分证明了预训练模型的强大优势,即使在 CIFAR-10 这种相对较小的数据集上,也能够通过微调取得优异的性能。
此外,由于训练集采用了数据增强操作,模型在训练初期可能会出现训练准确率暂时低于测试准确率的情况。这是因为数据增强增加了模型训练的难度,而测试集是标准的、未增强的图像,模型在测试集上预测相对轻松。随着训练的推进,模型逐渐适应了数据增强带来的变化,训练准确率和测试准确率之间的差距逐渐缩小。
以下是完整的训练代码:
# 主函数:训练模型
def main():# 参数设置epochs = 40 # 总训练轮次freeze_epochs = 5 # 冻结卷积层的轮次learning_rate = 1e-3 # 初始学习率weight_decay = 1e-4 # 权重衰减# 创建 ResNet18 模型(加载预训练权重)model = create_resnet18(pretrained=True, num_classes=10)# 定义优化器和损失函数optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)criterion = nn.CrossEntropyLoss()# 定义学习率调度器scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)# 开始训练(前 5 轮冻结卷积层,之后解冻)final_accuracy = train_with_freeze_schedule(model=model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,epochs=epochs,freeze_epochs=freeze_epochs)print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")if __name__ == "__main__":main()
@浙大疏锦行
补充-60日计划day44,pynote中day55
相关文章:
day46 python预训练模型补充
目录 一、预训练模型的背景知识 二、实验过程 (一)实验环境与数据准备 (二)预训练模型的选择与适配 (三)训练策略 三、实验结果与分析 四、学习总结与展望 一、预训练模型的背景知识 在传统的神经网…...
CCPC chongqing 2025 H
题目链接:https://codeforces.com/gym/105887 题目背景: 方框上有上下两排小球,下面的紧贴框底,上面的部分贴框顶,每牌小球上都有一个一个数字(1~n),将相同的小球连接到一起,是否在不交叉的情况…...

Java建造者模式(Builder Pattern)详解与实践
一、引言 在软件开发中,我们经常会遇到需要创建复杂对象的场景。例如,构建一个包含多个可选参数的对象时,传统的构造函数或Setter方法可能导致代码臃肿、难以维护。此时,建造者模式(Builder Pattern)便成为…...
ant-design4.xx实现数字输入框; 某些输入法数字需要连续输入两次才显示
目录 一、问题 二、解决方法 三、总结 一、问题 1.代码里有一个基于ant封装的公共组件数字输入框,测试突然说 无效了,输入其他字符也会显示;改了只有又发现某些 输入法 需要连续输入两次 才能显示出来。 二、解决方法 1.就离谱࿰…...
使用ORM Bee (ormbee) ,如何利用SQLAlchemy的模型生成数据库表.
使用ORM Bee (ormbee) ,如何利用SQLAlchemy的模型生成数据库表. 将原来SQLAlchemy的模型,修改依赖为: from bee.helper import SQLAlchemy 然后就可以开始生成了。很简单,主要是两个接口。 db.create_all(True) #创建所有模型的表…...
【win | 自动更新关闭】win11
利用本地组策略编辑器 对于Windows 11专业版或更高版本的用户,可以利用本地组策略编辑器来完全关闭自动更新。按下“WinR”键,输入“gpedit.msc”并回车。在本地组策略编辑器中,依次展开“计算机配置”>“管理模板”>“Windows组件”&…...

win32相关(IAT HOOK)
IAT HOOK 什么是IAT Hook? IAT Hook(Import Address Table Hook,导入地址表钩子)是一种Windows平台下的API钩取技术,通过修改目标程序的导入地址表(IAT)来拦截和重定向API调用 在我们之前学习pe文件结构的导入表时&am…...
大模型高效提示词Prompt编写指南
大模型高效Prompt编写指南 一、引言二、核心原则1. 清晰性原则:明确指令与期望2. 具体性原则:提供详细上下文3. 结构化原则:组织信息的逻辑与层次4. 迭代优化原则:通过反馈改进Prompt5. 简洁性原则:避免冗余信息 三、文…...

零基础玩转物联网-串口转以太网模块如何快速实现与TCP服务器通信
目录 1 前言 2 环境搭建 2.1 硬件准备 2.2 软件准备 2.3 驱动检查 3 TCP服务器通信配置与交互 3.1 硬件连接 3.2 开启TCP服务器 3.3 打开配置工具读取基本信息 3.4 填写连接参数进行连接 3.5 通信测试 4 总结 1 前言 TCP是TCP/IP体系中的传输层协议,全称为Transmiss…...
十一、【ESP32开发全栈指南: TCP通信服务端】
一、TCP与UDP协议对比 1.1 基本特性比较 TCP(传输控制协议)和UDP(用户数据报协议)是两种最常用的传输层协议,它们在ESP32网络编程中都有广泛应用: 连接方式 TCP是面向连接的协议,通信前需要先建立连接(三次握手)UDP是无连接的协议ÿ…...

ESP32开发之LED闪烁和呼吸的实现
硬件电路介绍GPIO输出模式GPIO配置过程闪烁灯的源码LED PWM的控制器(LEDC)概述LEDC配置过程及现象整体流程 硬件电路介绍 电路图如下: 只要有硬件基础的应该都知道上图中,当GPIO4的输出电平为高时,LED灯亮,反之则熄灭。如果每间…...

【产品业务设计】支付业务设计规范细节记录,含订单记录、支付业务记录、支付流水记录、退款业务记录
【产品业务设计】支付业务设计规范细节记录,含订单记录、支付业务记录、支付流水记录 前言 我为什么要写这个篇文章 总结设计经验生成设计模板方便后期快速搭建 一个几张表 一共5张表; 分别是: 订单主表:jjy_orderMain订单产…...

2025软件供应链安全最佳实践︱证券DevSecOps下供应链与开源治理实践
项目背景:近年来,云计算、AI人工智能、大数据等信息技术的不断发展、各行各业的信息电子化的步伐不断加快、信息化的水平不断提高,网络安全的风险不断累积,金融证券行业面临着越来越多的威胁挑战。特别是近年以来,开源…...
Linux安装jdk、tomcat
1、安装jdk sudo yum install -y java-1.8.0-openjdk-devel碰到的问题:/var/run/yum.pid 已被锁定 Another app is currently holding the yum lock; waiting for it to exit… https://blog.csdn.net/u013669912/article/details/131259156 参考&#…...

WebRTC通话原理与入门难度实战指南
波煮的实习公司主要是音视频业务,所以最近在补习WebRTC的相关内容,会不定期给大家分享学习心得和笔记。 文章目录 WebRTC通话原理进行媒体协商:彼此要了解对方支持的媒体格式网络协商:彼此要了解对方的网络情况,这样才…...

N元语言模型 —— 一文讲懂!!!
目录 引言 一. 基本知识 二.参数估计 三.数据平滑 一.加1法 二.减值法/折扣法 编辑 1.Good-Turing 估计 编辑 2.Back-off (后备/后退)方法 3.绝对减值法 编辑4.线性减值法 5.比较 三.删除插值法(Deleted interpolation) 四.模型自适应 引言 本章节讲的…...

.NET 9中的异常处理性能提升分析:为什么过去慢,未来快
一、为什么要关注.NET异常处理的性能 随着现代云原生、高并发、分布式场景的大量普及,异常处理(Exception Handling)早已不再只是一个冷僻的代码路径。在高复杂度的微服务、网络服务、异步编程环境下,服务依赖的外部资源往往不可…...

Mac 安装git心路历程(心累版)
省流版:直接安装Xcode命令行工具即可,不用安Xcode。 git下载官网 第一部分 上网初步了解后,打算直接安装Binary installer,下载完安装时,苹果还阻止安装,只好在“设置–安全性与隐私”最下面的提示进行安…...

计算机网络第2章(下):物理层传输介质与核心设备全面解析
目录 一、传输介质1.1 传输介质的分类1.2 导向型传输介质1.2.1 双绞线(Twisted Pair)1.2.2 同轴电缆(Coaxial Cable)1.2.3 光纤(Optical Fiber)1.2.4 以太网对有线传输介质的命名规则 1.3 非导向型传输介质…...
Qt Creator 11.0创建ROS2 Humble工程
Qt Creator 11.0创建ROS2 Humble项目工程 安装ROSProjectManager插件创建ROS2项目在src下添加packagegit clone ROS2功能包编译运行安装ROSProjectManager插件 安装ROSProjectManager的主要流程参考官方的流程,地址(ros_qtc_plugin)。 此处采用二进制安装: sudo apt inst…...

C# 类和继承(扩展方法)
扩展方法 在迄今为止的内容中,你看到的每个方法都和声明它的类关联。扩展方法特性扩展了这个边 界,允许编写的方法和声明它的类之外的类关联。 想知道如何使用这个特性,请看下面的代码。它包含类MyData,该类存储3个double类型 的…...
机器学习复习3--模型的选择
选择合适的机器学习模型是机器学习项目成功的关键一步。这通常不是一个一蹴而就的过程,而是需要综合考虑多个因素,并进行实验和评估。 1. 理解问题本质 这是模型选择的首要步骤。需要清晰地定义试图解决的问题类型: 监督学习 : 数据集包含…...

MySQL复杂SQL(多表联查/子查询)详细讲解
🤟致敬读者 🟩感谢阅读🟦笑口常开🟪生日快乐⬛早点睡觉 📘博主相关 🟧博主信息🟨博客首页🟫专栏推荐🟥活动信息 文章目录 MySQL复杂SQL(多表联查/子查询&a…...

STM32使用土壤湿度传感器
1.1 介绍: 土壤湿度传感器是一种传感装置,主要用于检测土壤湿度的大小,并广泛应用于汽车自动刮水系统、智能灯光系统和智能天窗系统等。传感器采用优质FR-04双料,大面积5.0 * 4.0厘米,镀镍处理面。 它具有抗氧化&…...
在C++中,头文件(.h或.hpp)的标准写法
目录 1.头文件保护(Include Guards)2.包含必要的标准库头文件3.前向声明(Forward Declarations)4.命名空间5.注释示例1:基础头文件示例2:包含模板和内联函数的头文件示例3:C11风格的枚举类头文件…...
Axios学习笔记
Axios简介 axios前端异步请求库类似JQuery ajax技术, ajax用来在页面发起异步请求到后端服务,并将后端服务响应数据渲染到页面上, jquery推荐ajax技术,但vue里面并不推荐在使用jquery框架,vue推荐使用axios异步请求库。…...
Langchain学习笔记(十一):Chain构建与组合技巧
注:本文是Langchain框架的学习笔记;不是教程!不是教程!内容可能有所疏漏,欢迎交流指正。后续将持续更新学习笔记,分享我的学习心得和实践经验。 前言 在LangChain的发展过程中,API设计经历了重…...
【判断既约分数】2022-4-3
缘由既约分数,除了辗转相除法-编程语言-CSDN问答 void 判断既约分数() {int a 1, b 2020, aa b, y 2, gs 0;while (aa){while (a < b){while (y < a && y < aa)if (a%y 0 && aa%y 0)a, y 2;elsey;if (a < b)gs; else;a, y 2;…...

Windows平台RTSP/RTMP播放器C#接入详解
大牛直播SDK在Windows平台下的RTSP、RTMP播放器模块,基于自研高性能内核,具备极高的稳定性与行业领先的超低延迟表现。相比传统基于FFmpeg或VLC的播放器实现,SmartPlayer不仅支持RTSP TCP/UDP自动切换、401鉴权、断网重连等网络复杂场景自适应…...
深圳SMT贴片工艺优化关键步骤
内容概要 深圳SMT贴片工艺优化作为现代电子制造的核心环节,聚焦于提升生产精度与稳定性。其技术框架围绕三大核心维度展开:温度动态调控、设备协同适配与工艺缺陷预判。通过精密温度曲线控制系统,实现回流焊环节的热能梯度精准匹配ÿ…...