当前位置: 首页 > article >正文

day 44

使用DenseNet预训练模型对cifar10数据集进行训练

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 1. 数据预处理(训练集增强,测试集标准化)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 2. 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=test_transform
)# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定义DenseNet模型
def create_densenet(pretrained=True, num_classes=10):# 修改为加载DenseNet121预训练模型model = models.densenet121(pretrained=pretrained)# 修改最后一层全连接层in_features = model.classifier.in_featuresmodel.classifier = nn.Linear(in_features, num_classes)return model.to(device)# 5. 冻结/解冻模型层的函数
def freeze_model(model, freeze=True):"""冻结或解冻模型的卷积层参数"""# 冻结/解冻除classifier层外的所有参数for name, param in model.named_parameters():if 'classifier' not in name:param.requires_grad = not freeze# 打印冻结状态frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)total_params = sum(p.numel() for p in model.parameters())if freeze:print(f"已冻结模型卷积层参数 ({frozen_params}/{total_params} 参数)")else:print(f"已解冻模型所有参数 ({total_params}/{total_params} 参数可训练)")return model# 6. 训练函数(支持阶段式训练)
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):"""前freeze_epochs轮冻结卷积层,之后解冻所有层进行训练"""train_loss_history = []test_loss_history = []train_acc_history = []test_acc_history = []all_iter_losses = []iter_indices = []# 初始冻结卷积层if freeze_epochs > 0:model = freeze_model(model, freeze=True)for epoch in range(epochs):# 解冻控制:在指定轮次后解冻所有层if epoch == freeze_epochs:model = freeze_model(model, freeze=False)# 解冻后调整优化器(可选)optimizer.param_groups[0]['lr'] = 1e-4  # 降低学习率防止过拟合model.train()  # 设置为训练模式running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 记录Iteration损失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 统计训练指标running_loss += iter_loss_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()# 每100批次打印进度if (batch_idx + 1) % 100 == 0:print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "f"| 单Batch损失: {iter_loss:.4f}")# 计算 epoch 级指标epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_train# 测试阶段model.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_test# 记录历史数据train_loss_history.append(epoch_train_loss)test_loss_history.append(epoch_test_loss)train_acc_history.append(epoch_train_acc)test_acc_history.append(epoch_test_acc)# 更新学习率调度器if scheduler is not None:scheduler.step(epoch_test_loss)# 打印 epoch 结果print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} "f"| 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")# 绘制损失和准确率曲线plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc  # 返回最终测试准确率# 7. 绘制Iteration损失曲线
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7)plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('训练过程中的Iteration损失变化')plt.grid(True)plt.show()# 8. 绘制Epoch级指标曲线
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 5))# 准确率曲线plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='训练准确率')plt.plot(epochs, test_acc, 'r-', label='测试准确率')plt.xlabel('Epoch')plt.ylabel('准确率 (%)')plt.title('准确率随Epoch变化')plt.legend()plt.grid(True)# 损失曲线plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='训练损失')plt.plot(epochs, test_loss, 'r-', label='测试损失')plt.xlabel('Epoch')plt.ylabel('损失值')plt.title('损失值随Epoch变化')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 主函数:训练模型
def main():# 参数设置epochs = 40  # 总训练轮次freeze_epochs = 5  # 冻结卷积层的轮次learning_rate = 1e-3  # 初始学习率weight_decay = 1e-4  # 权重衰减# 创建DenseNet模型(加载预训练权重)model = create_densenet(pretrained=True, num_classes=10)# 定义优化器和损失函数optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)criterion = nn.CrossEntropyLoss()# 定义学习率调度器scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)# 开始训练(前5轮冻结卷积层,之后解冻)final_accuracy = train_with_freeze_schedule(model=model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,epochs=epochs,freeze_epochs=freeze_epochs)print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")# # 保存模型# torch.save(model.state_dict(), 'resnet18_cifar10_finetuned.pth')# print("模型已保存至: resnet18_cifar10_finetuned.pth")if __name__ == "__main__":main()

注意:

1.DenseNet 最后一层全连接层的名称是 classifier ,并非 fc 因此在 freeze_model 函数里,要注意改一下名。

2.残差连接

在传统的神经网络中,每一层的输出都是对上一层输入进行一系列非线性变换的结果。而在残差网络(ResNet)中,引入了残差块(Residual Block),残差块通过残差连接将输入直接加到经过非线性变换后的输出上。

假设一个神经网络层的输入为 $x$,期望学习的映射为 $H(x)$,在传统网络中,该层需要直接学习 $H(x)$。而在残差网络中,将该层改为学习残差函数 $F(x) = H(x) - x$,则输出变为 $y = F(x) + x$,这里的 $x$ 到 $y$ 的连接就是残差连接。

@浙大疏锦行

相关文章:

day 44

使用DenseNet预训练模型对cifar10数据集进行训练 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms, models from torch.utils.data import DataLoader import matplotlib.pyplot as plt import os# 设置中文字体…...

鸿蒙OSUniApp开发跨平台AR扫描识别应用:HarmonyOS实践指南#三方框架 #Uniapp

UniApp开发跨平台AR扫描识别应用:HarmonyOS实践指南 前言 随着增强现实(AR)技术在移动应用中的广泛应用,越来越多的开发者需要在跨平台应用中实现AR功能。本文将深入探讨如何使用UniApp框架开发一个高性能的AR扫描识别应用&…...

NER实践总结,记录一下自己实践遇到的各种问题。

更。 没卡,跑个模型休息好几天,又闲又急。 一开始直接套用了别人的代码进行实体识别,结果很差,原因是他的词表没有我需要的东西,我是用的医学文本。代码直接在github找了改的,用的是BERT的Chinese版本。 然…...

微信小程序实现运动能耗计算

微信小程序实现运动能耗计算 近我做了一个挺有意思的微信小程序,能够实现运动能耗的计算。只需要输入性别、年龄、体重、运动时长和运动类型这些信息,就能算出对应的消耗热量。 具体来说,在小程序里,性别不同,身体基…...

iTunes 无法备份 iPhone:10 种解决方法

Apple 设备是移动设备市场上最先进的产品之一,但有些人遇到过 iTunes 因出现错误而无法备份 iPhone 的情况。iTunes 拒绝备份 iPhone 时,可能会令人非常沮丧。不过,幸运的是,我们有 10 种有效的方法可以解决这个问题。您可以按照以…...

施耐德特价型号伺服电机VIA0703D31A1022、常见故障

⚙️ ‌一、启动类故障‌ ‌电机无法启动‌ ‌可能原因‌:电源未接通、制动器未释放、接线错误或控制器故障。‌解决措施‌: 检查电源线路及断路器状态;验证制动器是否打开(带制动器型号);核对电机与控制器…...

LangChain4J 使用实践

这里写目录标题 大模型应用场景&#xff1a;创建一个测试示例AIService聊天记忆实现简单实现聊天记录记忆MessageWindowChatMemory实现聊天记忆 隔离聊天记忆聊天记忆持久化 添加AI提示词 大模型应用场景&#xff1a; 创建一个测试示例 导入依赖 <dependency><groupI…...

慢SQL调优(二):大表查询

最近在工作中写SQL出现几次慢SQL的BUG&#xff0c;总结下来归根到底就是因为大表的原因~这表有多大呢&#xff0c;执行 select COUNT(1) FROM position 是出不来结果滴&#xff0c;每天保底新增1000条数据&#xff0c;可想而知有多大了&#xff0c;所以多次踩坑了这张表。所以…...

【C++】—— 从零开始封装 Map 与 Set:实现与优化

人生的态度是&#xff0c;抱最大的希望&#xff0c;尽最大的努力&#xff0c;做最坏的打算。 —— 柏拉图 《理想国》 目录 1、理论基石——深度剖析 BSTree、AVLTree 与 RBTree 的概念区别 2、迭代器机制——RBTree 迭代器的架构与工程实现 3、高级容器设计——Map 与 Set…...

内网穿透之Linux版客户端安装(神卓互联)

选择Linux系统版本 获取安装包 &#xff1a;https://www.shenzhuohl.com/download.html 这里以Ubuntu 18.04为例&#xff0c;其它版本方法类似 登录Ubuntu操作系统&#xff1a; 打开Ubuntu系统终端&#xff0c;更新版本 apt-get update 安装运行环境&#xff1a; 安装C 运…...

开疆智能Profinet转Profibus网关连接CMDF5-8ADe分布式IO配置案例

本案例是客户通过开疆智能研发的Profinet转Profibus网关将PLC的Profinet协议数据转换成IO使用的Profibus协议&#xff0c;操作步骤如下。 配置过程&#xff1a; Profinet一侧设置 1. 打开西门子组态软件进行组态&#xff0c;导入网关在Profinet一侧的GSD文件。 2. 新建项目并…...

华为云Flexus+DeepSeek征文|Flexus云服务器单机部署+CCE容器高可用部署快速搭建生产级的生成式AI应用

前引&#xff1a; 在AI技术高速演进的浪潮中&#xff0c;如何快速、高效、安全地搭建一个大模型应用平台&#xff0c;成为开发者和企业关注的焦点。近日&#xff0c;华为云推出的Flexus云服务器配合CCE容器引擎和Dify LLM应用开发平台&#xff0c;带来了极具吸引力的解决方案。…...

扫地机产品--材质传感器算法开发与虚拟示波器

扫地机产品–材质传感器算法开发与虚拟示波器 文章目录 扫地机产品--材质传感器算法开发与虚拟示波器**一、材质传感器的工作原理**二、核心功能与应用场景三、技术参数与产品示例四.MCU 与压电陶瓷超声波的材质检测技术方案实现原理分析4.1 超声波原理4.2表面类型检测4.3 超声…...

[蓝桥杯]上三角方阵

上三角方阵 题目描述 方阵的主对角线之上称为"上三角"。 请你设计一个用于填充 nn 阶方阵的上三角区域的程序。填充的规则是&#xff1a;使用 1&#xff0c;2&#xff0c;3.... 的自然数列&#xff0c;从左上角开始&#xff0c;按照顺时针方向螺旋填充。 例如&am…...

60天python训练计划----day44

DAY 44 预训练模型 知识点回顾&#xff1a; 预训练的概念常见的分类预训练模型图像预训练模型的发展史预训练的策略预训练代码实战&#xff1a;resnet18 一、预训练的概念 我们之前在训练中发现&#xff0c;准确率最开始随着epoch的增加而增加。随着循环的更新&#xff0c;参数…...

【JAVA版】意象CRM客户关系管理系统+uniapp全开源

一.介绍 CRM意象客户关系管理系统&#xff0c;是一个综合性的客户管理平台&#xff0c;旨在帮助企业高效地管理客户信息、商机、合同以及员工业绩。系统通过首页、系统管理、工作流程、审批中心、线索管理、客户管理、商机管理、合同管理、CRM系统、数据统计和系统配置等模块&…...

API异常信息如何实时发送到钉钉

#背景 对于一些重要的API&#xff0c;开发人员会非常关注API有没有报错&#xff0c;为了方便开发人员第一时间获取错误信息&#xff0c;我们可以使用插件来将API报错实时发送到钉钉群。 接下来我们就来实操如何实现 #准备工作 #创建钉钉群 如果已有钉钉群&#xff0c;可以跳…...

Python爬虫(48)基于Scrapy-Redis与深度强化学习的智能分布式爬虫架构设计与实践

目录 一、背景与行业痛点二、核心技术架构设计2.1 分布式爬虫基础架构2.2 深度强化学习模块 三、生产环境实践案例3.1 电商价格监控系统3.2 学术文献采集系统 四、高级优化技术4.1 联邦学习增强4.2 神经架构搜索&#xff08;NAS&#xff09; 五、总结&#x1f308;Python爬虫相…...

AtCoder Beginner Contest 407 E - Most Valuable Parentheses

AtCoder Beginner Contest 407 E - Most Valuable Parentheses E - Most Valuable Parentheses 反悔贪心算法 性质&#xff1a; 假设长度为 n n n&#xff0c; n ≡ 0 ( m o d 2 ) n \equiv 0 \pmod{2} n≡0(mod2) 的括号序列是合法的&#xff0c;那么有 n 2 \frac{n}{2}…...

(1-6-3)Java 多线程

目录 0.知识拓扑 1. 多线程相关概念 1.1 进程 1.2 线程 1.3 java 中的进程 与 线程概述 1.4 CPU、进程 与 线程的关系 2.多线程的创建方式 2.1 继承Thread类 2.2 实现Runnable接口 2.3 实现Callable接口 2.4 三种创建方式对比 3.线程同步 3.1 线程同步机制概述 …...

java31

1.网络编程 三要素&#xff1a; 网址实质上就是ip InetAddress: UDP通信程序&#xff1a; 多个接收端的地址都要加入同一个组播地址&#xff0c;这样发送端发信息&#xff0c;全部接收端都能接受到数据 广播的代码差不多&#xff0c;就是地址不一样而已 TCP通信程序&#xf…...

多模态之智能数字人

多模态下智能数字人的开发是一个复杂且系统性的工程,它融合了人工智能(AI)、计算机图形学、自然语言处理(NLP)、语音技术、计算机视觉(CV)等多个前沿领域。 多模态下智能数字人的开发流程规范 目标: 构建一个能够理解并生成多模态信息(文本、语音、视觉等),具备智…...

界面组件DevExpress WPF中文教程:Grid - 如何识别行和卡片?

DevExpress WPF拥有120个控件和库&#xff0c;将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpress WPF能创建有着强大互动功能的XAML基础应用程序&#xff0c;这些应用程序专注于当代客户的需求和构建未来新一代支持触摸的解决方案。 无论是Office办公软件…...

【HarmonyOS Next之旅】DevEco Studio使用指南(三十)

目录 1 -> 部署云侧工程 2 -> 通过CloudDev面板获取云开发资源支持 3 -> 通用云开发模板 3.1 -> 适用范围 3.2 -> 效果图 4 -> 总结 1 -> 部署云侧工程 可以选择在云函数和云数据库全部开发完成后&#xff0c;将整个云工程资源统一部署到AGC云端。…...

AI基础知识(LLM、prompt、rag、embedding、rerank、mcp、agent、多模态)

AI基础知识&#xff08;LLM、prompt、rag、embedding、rerank、mcp、agent、多模态&#xff09; 1、LLM大语言模型 --基于​​深度学习技术​​&#xff0c;通过​​海量文本数据训练​​而成的超大规模人工智能模型&#xff0c;能够理解、生成和推理自然语言文本 --产品&…...

[蓝桥杯]高僧斗法

高僧斗法 题目描述 古时丧葬活动中经常请高僧做法事。仪式结束后&#xff0c;有时会有"高僧斗法"的趣味节目&#xff0c;以舒缓压抑的气氛。 节目大略步骤为&#xff1a;先用粮食&#xff08;一般是稻米&#xff09;在地上"画"出若干级台阶&#xff08;…...

pycharm F2 修改文件名 修改快捷键

菜单&#xff1a;File-> Setting&#xff0c; Keymap中搜索 Rename&#xff0c; 其中&#xff0c;有 Refactor-> Rename&#xff0c;右键添加快捷键&#xff0c;F2&#xff0c;删除原有快捷键就可以了。...

Python Flask中启用AWS Secrets Manager+AWS Parameter Store配置中心

问题 最近需要改造一个Python的Flask项目。需要在这个项目中添加AWS Secrets Manager作为配置中心&#xff0c;主要是数据库相关配置。 前提 得预先在Amazon RDS里面新建好数据库用户和数据库&#xff0c;以AWS Aurora为例子&#xff0c;建库和建用户语句类似如下&#xff1…...

机器学习与深度学习10-支持向量机02

目录 前文回顾6.如何构建SVM7.SVM与多分类问题8.SVM与逻辑回归9.SVM的可扩展性10.SVM的适用性和局限性 前文回顾 上一篇文章链接&#xff1a;地址 6.如何构建SVM 选择合适的核函数和超参数来构建支持向量机&#xff08;SVM&#xff09;模型通常需要一定的经验和实验。以下是…...

《深入解析UART协议及其硬件实现》-- 第二篇:UART硬件架构设计与FPGA实现

第二篇&#xff1a;UART硬件架构设计与FPGA实现 1. 模块化架构设计 1.1 系统级框图与时钟域划分 核心模块划分 &#xff1a; 发送模块&#xff08;TX&#xff09; &#xff1a;负责数据帧组装与串行输出。 接收模块&#xff08;RX&#xff09; &#xff1a;负责串行数据采样与…...