pytorch实战-图像分类(二)(模型训练及验证)(基于迁移学习(理解+代码))
目录
1.迁移学习概念
2.数据预处理
3.训练模型(基于迁移学习)
3.1选择网络,这里用resnet
3.2如果用GPU训练,需要加入以下代码
3.3卷积层冻结模块
3.4加载resnet152模
3.5解释initialize_model函数
3.6迁移学习网络搭建
3.7优化器
3.8训练模块(可以理解为主函数)
3.9开始训练
3.10微调
4.测试模型
4.1加载训练好的模型
4.2测试数据预处理
4.3数据展示
4.4提取测试数据集
4.5计算提取数据集的预测结果
4.6展示预测结果
参考文献
1.迁移学习概念
先说一下深度学习常见的问题:
1.数据集不够,通常用数据增强解决。
2.参数难以确定,训练时间长,这就需要用迁移学习来解决
什么叫迁移学习呢:比方说有一个对100w的自行车数据集,并用VGG模型训练好的网络,而此时你想训练一个1w自行车数据集(虽然对象一样,但采集的数据会不同),也用VGG模型进行训练,你发现,你们数据集的对象一样,选用的网络模型一样,此时在初始化自己模型权重(就是卷积层,池化层和全连接层的参数)时,可以用人家训练好的模型参数(如果不这样就需要随机初始化模型权重),这样做可以节省大量寻找最优参数的时间,又可以保证参数的准确。
总结:迁移学习就是用别人的东西训练自己的东西,但要注意,为了使用别人的模型参数,要保证自己的数据对象、网络结构、输入和输出数据的结构和别人相同。比方说,别人识别狗,你不能识别 猫,别人用VGG你不能用resnet,别人输入和输入图像大小是224×224.你不能是256×256。
进一步理解迁移学习的使用1:看下图最大的红框,表示卷积层,当用别人的模型时,对卷积层的两种处理方式。
A:作为自己模型权重的初始化参数。
B:冻结卷积层网络,意思是直接用别人的参数,不再更新。冻结卷积层网络又分几种情况。
B1:当数据量小时,冻结第二大红框表示的卷积层,剩下卷积层进行更新。因为数据量小时,容易过拟合,直接用别人呢参数最好。
B2:当数据量中等时冻结最小红框表示的卷积层,剩下的卷积层进行更行。
B3:当数据量足够大时,不冻结卷积层,用A的方法,只作为自己模型权重的初始化参数。数据量大时,虽然对象一样,但毕竟数据不同,会有一定差异,更新参数是最优选择。
进一步理解迁移学习的使用2:说完卷积层,在说一下全连接层,必须要注意不管卷积层选A还是B,全连接层都是要更新的,原因在于,别人模型进行图像分类可能是进行1000个分类,而你只进行100或者999个分类,那么全连接层的参数肯定是不同的。
2.数据预处理
上接该文:pytorch实战-图像分类(一)(数据预处理)
3.训练模型(基于迁移学习)
3.1选择网络,这里用resnet
model_name = 'resnet' #可选的比较多 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
#是否用人家训练好的特征来做
feature_extract = True
3.2如果用GPU训练,需要加入以下代码
# 是否用GPU训练
train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available. Training on CPU ...')
else:print('CUDA is available! Training on GPU ...')device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
3.3卷积层冻结模块
def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = False
3.4加载resnet152模
注意:resnet152模型就是别人的模型。
model_ft = models.resnet152()
model_ft
3.5解释initialize_model函数
本小节只是截取pytorch官网的一个例子,用initialize_model说明在pytoch中迁移学习怎么使用,不属于本文代码
具体操作如下:
1.下载别人的模型参数,这里下载restnet152模型
2.选择需要冻结的卷积层
3.改变全连接层的输出个数,这里将1000改为102
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):# 选择合适的模型,不同模型的初始化方法稍微有点区别model_ft = Noneinput_size = 0if model_name == "resnet":""" Resnet152"""model_ft = models.resnet152(pretrained=use_pretrained) #下载resnet152模型set_parameter_requires_grad(model_ft, feature_extract) #选择冻结哪部分卷积层num_ftrs = model_ft.fc.in_features #全连接层的输入比方说全连接层是2048×1000,这就是2048.model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102),nn.LogSoftmax(dim=1)) #原resnet152的全连接层输出是1000,自己模型需要的输出是102,进行改动。input_size = 224return model_ft, input_size
3.6迁移学习网络搭建
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)#GPU计算
model_ft = model_ft.to(device)# 模型保存
filename='checkpoint.pth'# 是否训练所有层
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:params_to_update = []for name,param in model_ft.named_parameters():if param.requires_grad == True:params_to_update.append(param)print("\t",name)
else:for name,param in model_ft.named_parameters():if param.requires_grad == True:print("\t",name)
3.7优化器
就是用该方法更新模型参数
# 优化器设置
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)#学习率每7个epoch衰减成原来的1/10
#最后一层已经LogSoftmax()了,所以不能nn.CrossEntropyLoss()来计算了,nn.CrossEntropyLoss()相当于logSoftmax()和nn.NLLLoss()整合
criterion = nn.NLLLoss()
3.8训练模块(可以理解为主函数)
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False,filename=filename):since = time.time() #best_acc = 0"""checkpoint = torch.load(filename)best_acc = checkpoint['best_acc']model.load_state_dict(checkpoint['state_dict'])optimizer.load_state_dict(checkpoint['optimizer'])model.class_to_idx = checkpoint['mapping']"""model.to(device)val_acc_history = []train_acc_history = []train_losses = []valid_losses = []LRs = [optimizer.param_groups[0]['lr']]best_model_wts = copy.deepcopy(model.state_dict())for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 训练和验证for phase in ['train', 'valid']:if phase == 'train':model.train() # 训练else:model.eval() # 验证running_loss = 0.0running_corrects = 0# 把数据都取个遍for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 清零optimizer.zero_grad()# 只有训练的时候计算和更新梯度with torch.set_grad_enabled(phase == 'train'):if is_inception and phase == 'train':outputs, aux_outputs = model(inputs)loss1 = criterion(outputs, labels)loss2 = criterion(aux_outputs, labels)loss = loss1 + 0.4*loss2else:#resnet执行的是这里outputs = model(inputs)loss = criterion(outputs, labels)_, preds = torch.max(outputs, 1)# 训练阶段更新权重if phase == 'train':loss.backward()optimizer.step()# 计算损失running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(dataloaders[phase].dataset)epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)time_elapsed = time.time() - sinceprint('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 得到最好那次的模型if phase == 'valid' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())state = {'state_dict': model.state_dict(),'best_acc': best_acc,'optimizer' : optimizer.state_dict(),}torch.save(state, filename)if phase == 'valid':val_acc_history.append(epoch_acc)valid_losses.append(epoch_loss)scheduler.step(epoch_loss)if phase == 'train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRs.append(optimizer.param_groups[0]['lr'])print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# 训练完后用最好的一次当做模型最终的结果model.load_state_dict(best_model_wts)return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
3.9开始训练
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20, is_inception=(model_name=="inception"))
3.10微调
在2.9中得到的模型,是冻结了卷积层,只训练了全连接层,所以此时希望在此基础上再对卷积层进行训练。
for param in model_ft.parameters():param.requires_grad = True# 再继续训练所有的参数,学习率调小一点
optimizer = optim.Adam(params_to_update, lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)# 损失函数
criterion = nn.NLLLoss()# Load the checkpoint,加载自己的模型checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
#model_ft.class_to_idx = checkpoint['mapping']model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=10, is_inception=(model_name=="inception"))
4.测试模型
4.1加载训练好的模型
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)# GPU模式
model_ft = model_ft.to(device)# 保存文件的名字
filename='seriouscheckpoint.pth'# 加载模型
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
4.2测试数据预处理
1.测试数据处理方法需要跟训练时一直才可以
2.crop操作的目的是保证输入的大小是一致的
3.标准化操作也是必须的,用跟训练数据相同的mean和std,但是需要注意一点训练数据是在0-1上进行标准化,所以测试数据也需要先归一化
4.PyTorch中颜色通道是第一个维度,跟很多工具包都不一样,需要转换
def process_image(image_path):# 读取测试数据img = Image.open(image_path)# Resize,thumbnail方法只能进行缩小,所以进行了判断if img.size[0] > img.size[1]:img.thumbnail((10000, 256))else:img.thumbnail((256, 10000))# Crop操作left_margin = (img.width-224)/2bottom_margin = (img.height-224)/2right_margin = left_margin + 224top_margin = bottom_margin + 224img = img.crop((left_margin, bottom_margin, right_margin, top_margin))# 相同的预处理方法img = np.array(img)/255mean = np.array([0.485, 0.456, 0.406]) #provided meanstd = np.array([0.229, 0.224, 0.225]) #provided stdimg = (img - mean)/std# 注意颜色通道应该放在第一个位置img = img.transpose((2, 0, 1))return img
4.3数据展示
def imshow(image, ax=None, title=None):"""展示数据"""if ax is None:fig, ax = plt.subplots()# 颜色通道还原image = np.array(image).transpose((1, 2, 0))# 预处理还原mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image = std * image + meanimage = np.clip(image, 0, 1)ax.imshow(image)ax.set_title(title)return ax
4.4提取测试数据集
# 得到一个batch的测试数据
dataiter = iter(dataloaders['valid'])
images, labels = dataiter.next()model_ft.eval()if train_on_gpu:output = model_ft(images.cuda())
else:output = model_ft(images)
4.5计算提取数据集的预测结果
_, preds_tensor = torch.max(output, 1)preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
preds
4.6展示预测结果
fig=plt.figure(figsize=(20, 20))
columns =4
rows = 2for idx in range (columns*rows):ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])plt.imshow(im_convert(images[idx]))ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),color=("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
参考文献
1.6-训练结果与模型保存_哔哩哔哩_bilibili
相关文章:

pytorch实战-图像分类(二)(模型训练及验证)(基于迁移学习(理解+代码))
目录 1.迁移学习概念 2.数据预处理 3.训练模型(基于迁移学习) 3.1选择网络,这里用resnet 3.2如果用GPU训练,需要加入以下代码 3.3卷积层冻结模块 3.4加载resnet152模 3.5解释initialize_model函数 3.6迁移学习网络搭建 3.…...

b 树和 b+树的理解
项目场景: 图灵奖获得者(Niklaus Wirth )说过: 程序 数据结构 算法, 也就说我们无时无刻 都在和数据结构打交道。 只是作为 Java 开发,由于技术体系的成熟度较高,使得大部分人认为࿱…...
正则表达式 —— Awk
Awk awk:文本三剑客之一,是功能最强大的文本工具 awk也是按行来进行操作,对行操作完之后,可以根据指定命令来对行取列 awk的分隔符,默认分隔符是空格或tab键,多个空格会压缩成一个 awk的用法 awk的格式…...

国芯新作 | 四核Cortex-A53@1.4GHz,仅168元起?含税?哇!!!
创龙科技SOM-TLT507是一款基于全志科技T507-H处理器设计的4核ARM Cortex-A53全国产工业核心板,主频高达1.416GHz。核心板CPU、ROM、RAM、电源、晶振等所有元器件均采用国产工业级方案,国产化率100%。 核心板通过邮票孔连接方式引出MIPI CSI、HDMI OUT、…...

【MyBatis】 框架原理
目录 10.3【MyBatis】 框架原理 10.3.1 【MyBatis】 整体架构 10.3.2 【MyBatis】 运行原理 10.4 【MyBatis】 核心组件的生命周期 10.4.1 SqlSessionFactoryBuilder 10.4.2 SqlSessionFactory 10.4.3 SqlSession 10.4.4 Mapper Instances 与 Hibernate 框架相比&#…...

三、线性工作流
再生产的各个环节,正确使用gamma编码及gamma解码,使得最终得到的颜色数据与最初输入的物理数据一致。如果使用gamma空间的贴图,在传给着色器前需要从gamma空间转到线性空间。 如果不在线性空间下进行渲染,会产生的问题:…...

2023华数杯数学建模A题思路 - 隔热材料的结构优化控制研究
# 1 赛题 A 题 隔热材料的结构优化控制研究 新型隔热材料 A 具有优良的隔热特性,在航天、军工、石化、建筑、交通等 高科技领域中有着广泛的应用。 目前,由单根隔热材料 A 纤维编织成的织物,其热导率可以直接测出;但是 单根隔热…...

Zabbix分布式监控Web监控
目录 1 概述2 配置 Web 场景2.1 配置步骤2.2 显示 3 Web 场景步骤3.1 创建新的 Web 场景。3.2 定义场景的步骤3.3 保存配置完成的Web 监控场景。 4 Zabbix-Get的使用 1 概述 您可以使用 Zabbix 对多个网站进行可用性方面监控: 要使用 Web 监控,您需要定…...
PHP从入门到精通—PHP开发入门-PHP概述、PHP开发环境搭建、PHP开发环境搭建、第一个PHP程序、PHP开发流程
每开始学习一门语言,都要了解这门语言和进行开发环境的搭建。同样,学生开始PHP学习之前,首先要了解这门语言的历史、语言优势等内容以及了解开发环境的搭建。 PHP概述 认识PHP PHP最初是由Rasmus Lerdorf于1994年为了维护个人网页而编写的一…...
【LeetCode-中等】722. 删除注释
题目链接 722. 删除注释 标签 字符串 步骤 Step1. 先将source合并为一个字符串进行处理,中间补上’\n’,方便后续确定注释开始、结束位置。 string combined; for (auto str : source) {combined str "\n"; }Step2. 定义数组 toDel&am…...
rust里如何判断字符串是否相等呢?
在 Rust 中,有几种方法可以判断字符串是否相等。下面是其中几种常见的方法: 使用 运算符:可以直接使用 运算符比较两个字符串是否相等。例如: fn main() {let str1 "hello";let str2 "world";if str1 …...

python基本知识学习
一、输出语句 在控制台输出Hello,World! print("Hello,World!") 二、注释 单行注释:以#开头 # print("你好") 多行注释: 选中要注释的代码Ctrl/三单引号三双引号 # print("你好") # a1 # a2 print("Hello,World!&…...

vue3和typescript_组件
1 components下新建myComponent.vue 2 页面中引入组件,传入值,并且绑定事件函数。 3...

Qt+联想电脑管家
1.自定义按钮类 效果: (1)仅当未选中,未悬浮时 (2)其他三种情况,均如图 #ifndef BTN_H #define BTN_H#include <QPushButton> class btn : public QPushButton {Q_OBJECT public:btn(QWidget * parent nullptr);void set_normal_icon(…...

论文阅读-BotPercent: Estimating Twitter Bot Populations from Groups to Crowds
目录 摘要 引言 方法 数据集 BotPercent架构 实验结果 活跃用户中的Bot数量 Bot Population among Comment Sections Bot Participation in Content Moderation Votes Bot Population in Different Countries’ Politics 论文链接:https://arxiv.org/pdf/23…...

用于永磁同步电机驱动器的自适应SDRE非线性无传感器速度控制(MatlabSimulink实现)
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...

Spring Cloud+Spring Boot+Mybatis+uniapp+前后端分离实现知识付费平台免费搭建 qt
Java版知识付费源码 Spring CloudSpring BootMybatisuniapp前后端分离实现知识付费平台 提供职业教育、企业培训、知识付费系统搭建服务。系统功能包含:录播课、直播课、题库、营销、公司组织架构、员工入职培训等。 提供私有化部署,免费售…...

删除注释(力扣)
删除注释 题目 给一个 C 程序,删除程序中的注释。这个程序source是一个数组,其中source[i]表示第 i 行源码。 这表示每行源码由 ‘\n’ 分隔。 在 C 中有两种注释风格,行内注释和块注释。 字符串// 表示行注释,表示//和其右侧…...

阿里云AK创建
要在阿里云上创建 Access Key(AK),您需要按照以下步骤进行操作: 登录到阿里云控制台([https://www.aliyun.com/?utm_contentse_1014243503))。 点击右上方的主账号,点击“AccessKey管理”。 …...

OC与Swift的相互调用
OC调用Swift方法 1、在 Build Settings 搜索 Packaging ,设置 Defines Module 为 YES 2、新建 LottieBridge.swift 文件,自动生成桥 ProductName-Bridging-Header.h 3、在 LottieBridge.swift 中,定义Swift类继承于OC类,声明 obj…...
云原生核心技术 (7/12): K8s 核心概念白话解读(上):Pod 和 Deployment 究竟是什么?
大家好,欢迎来到《云原生核心技术》系列的第七篇! 在上一篇,我们成功地使用 Minikube 或 kind 在自己的电脑上搭建起了一个迷你但功能完备的 Kubernetes 集群。现在,我们就像一个拥有了一块崭新数字土地的农场主,是时…...

Appium+python自动化(十六)- ADB命令
简介 Android 调试桥(adb)是多种用途的工具,该工具可以帮助你你管理设备或模拟器 的状态。 adb ( Android Debug Bridge)是一个通用命令行工具,其允许您与模拟器实例或连接的 Android 设备进行通信。它可为各种设备操作提供便利,如安装和调试…...

华为云Flexus+DeepSeek征文|DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建
华为云FlexusDeepSeek征文|DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建 前言 如今大模型其性能出色,华为云 ModelArts Studio_MaaS大模型即服务平台华为云内置了大模型,能助力我们轻松驾驭 DeepSeek-V3/R1,本文中将分享如何…...

视觉slam十四讲实践部分记录——ch2、ch3
ch2 一、使用g++编译.cpp为可执行文件并运行(P30) g++ helloSLAM.cpp ./a.out运行 二、使用cmake编译 mkdir build cd build cmake .. makeCMakeCache.txt 文件仍然指向旧的目录。这表明在源代码目录中可能还存在旧的 CMakeCache.txt 文件,或者在构建过程中仍然引用了旧的路…...

Linux nano命令的基本使用
参考资料 GNU nanoを使いこなすnano基础 目录 一. 简介二. 文件打开2.1 普通方式打开文件2.2 只读方式打开文件 三. 文件查看3.1 打开文件时,显示行号3.2 翻页查看 四. 文件编辑4.1 Ctrl K 复制 和 Ctrl U 粘贴4.2 Alt/Esc U 撤回 五. 文件保存与退出5.1 Ctrl …...
解决:Android studio 编译后报错\app\src\main\cpp\CMakeLists.txt‘ to exist
现象: android studio报错: [CXX1409] D:\GitLab\xxxxx\app.cxx\Debug\3f3w4y1i\arm64-v8a\android_gradle_build.json : expected buildFiles file ‘D:\GitLab\xxxxx\app\src\main\cpp\CMakeLists.txt’ to exist 解决: 不要动CMakeLists.…...

【Linux】自动化构建-Make/Makefile
前言 上文我们讲到了Linux中的编译器gcc/g 【Linux】编译器gcc/g及其库的详细介绍-CSDN博客 本来我们将一个对于编译来说很重要的工具:make/makfile 1.背景 在一个工程中源文件不计其数,其按类型、功能、模块分别放在若干个目录中,mak…...

破解路内监管盲区:免布线低位视频桩重塑停车管理新标准
城市路内停车管理常因行道树遮挡、高位设备盲区等问题,导致车牌识别率低、逃费率高,传统模式在复杂路段束手无策。免布线低位视频桩凭借超低视角部署与智能算法,正成为破局关键。该设备安装于车位侧方0.5-0.7米高度,直接规避树枝遮…...

什么是VR全景技术
VR全景技术,全称为虚拟现实全景技术,是通过计算机图像模拟生成三维空间中的虚拟世界,使用户能够在该虚拟世界中进行全方位、无死角的观察和交互的技术。VR全景技术模拟人在真实空间中的视觉体验,结合图文、3D、音视频等多媒体元素…...
Modbus RTU与Modbus TCP详解指南
目录 1. Modbus协议基础 1.1 什么是Modbus? 1.2 Modbus协议历史 1.3 Modbus协议族 1.4 Modbus通信模型 🎭 主从架构 🔄 请求响应模式 2. Modbus RTU详解 2.1 RTU是什么? 2.2 RTU物理层 🔌 连接方式 ⚡ 通信参数 2.3 RTU数据帧格式 📦 帧结构详解 🔍…...