pytorch实战-图像分类(二)(模型训练及验证)(基于迁移学习(理解+代码))
目录
1.迁移学习概念
2.数据预处理
3.训练模型(基于迁移学习)
3.1选择网络,这里用resnet
3.2如果用GPU训练,需要加入以下代码
3.3卷积层冻结模块
3.4加载resnet152模
3.5解释initialize_model函数
3.6迁移学习网络搭建
3.7优化器
3.8训练模块(可以理解为主函数)
3.9开始训练
3.10微调
4.测试模型
4.1加载训练好的模型
4.2测试数据预处理
4.3数据展示
4.4提取测试数据集
4.5计算提取数据集的预测结果
4.6展示预测结果
参考文献
1.迁移学习概念
先说一下深度学习常见的问题:
1.数据集不够,通常用数据增强解决。
2.参数难以确定,训练时间长,这就需要用迁移学习来解决
什么叫迁移学习呢:比方说有一个对100w的自行车数据集,并用VGG模型训练好的网络,而此时你想训练一个1w自行车数据集(虽然对象一样,但采集的数据会不同),也用VGG模型进行训练,你发现,你们数据集的对象一样,选用的网络模型一样,此时在初始化自己模型权重(就是卷积层,池化层和全连接层的参数)时,可以用人家训练好的模型参数(如果不这样就需要随机初始化模型权重),这样做可以节省大量寻找最优参数的时间,又可以保证参数的准确。
总结:迁移学习就是用别人的东西训练自己的东西,但要注意,为了使用别人的模型参数,要保证自己的数据对象、网络结构、输入和输出数据的结构和别人相同。比方说,别人识别狗,你不能识别 猫,别人用VGG你不能用resnet,别人输入和输入图像大小是224×224.你不能是256×256。
进一步理解迁移学习的使用1:看下图最大的红框,表示卷积层,当用别人的模型时,对卷积层的两种处理方式。
A:作为自己模型权重的初始化参数。
B:冻结卷积层网络,意思是直接用别人的参数,不再更新。冻结卷积层网络又分几种情况。
B1:当数据量小时,冻结第二大红框表示的卷积层,剩下卷积层进行更新。因为数据量小时,容易过拟合,直接用别人呢参数最好。
B2:当数据量中等时冻结最小红框表示的卷积层,剩下的卷积层进行更行。
B3:当数据量足够大时,不冻结卷积层,用A的方法,只作为自己模型权重的初始化参数。数据量大时,虽然对象一样,但毕竟数据不同,会有一定差异,更新参数是最优选择。
进一步理解迁移学习的使用2:说完卷积层,在说一下全连接层,必须要注意不管卷积层选A还是B,全连接层都是要更新的,原因在于,别人模型进行图像分类可能是进行1000个分类,而你只进行100或者999个分类,那么全连接层的参数肯定是不同的。

2.数据预处理
上接该文:pytorch实战-图像分类(一)(数据预处理)
3.训练模型(基于迁移学习)
3.1选择网络,这里用resnet
model_name = 'resnet' #可选的比较多 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
#是否用人家训练好的特征来做
feature_extract = True
3.2如果用GPU训练,需要加入以下代码
# 是否用GPU训练
train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available. Training on CPU ...')
else:print('CUDA is available! Training on GPU ...')device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
3.3卷积层冻结模块
def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = False
3.4加载resnet152模
注意:resnet152模型就是别人的模型。
model_ft = models.resnet152()
model_ft
3.5解释initialize_model函数
本小节只是截取pytorch官网的一个例子,用initialize_model说明在pytoch中迁移学习怎么使用,不属于本文代码
具体操作如下:
1.下载别人的模型参数,这里下载restnet152模型
2.选择需要冻结的卷积层
3.改变全连接层的输出个数,这里将1000改为102
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):# 选择合适的模型,不同模型的初始化方法稍微有点区别model_ft = Noneinput_size = 0if model_name == "resnet":""" Resnet152"""model_ft = models.resnet152(pretrained=use_pretrained) #下载resnet152模型set_parameter_requires_grad(model_ft, feature_extract) #选择冻结哪部分卷积层num_ftrs = model_ft.fc.in_features #全连接层的输入比方说全连接层是2048×1000,这就是2048.model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102),nn.LogSoftmax(dim=1)) #原resnet152的全连接层输出是1000,自己模型需要的输出是102,进行改动。input_size = 224return model_ft, input_size
3.6迁移学习网络搭建
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)#GPU计算
model_ft = model_ft.to(device)# 模型保存
filename='checkpoint.pth'# 是否训练所有层
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:params_to_update = []for name,param in model_ft.named_parameters():if param.requires_grad == True:params_to_update.append(param)print("\t",name)
else:for name,param in model_ft.named_parameters():if param.requires_grad == True:print("\t",name)
3.7优化器
就是用该方法更新模型参数
# 优化器设置
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)#学习率每7个epoch衰减成原来的1/10
#最后一层已经LogSoftmax()了,所以不能nn.CrossEntropyLoss()来计算了,nn.CrossEntropyLoss()相当于logSoftmax()和nn.NLLLoss()整合
criterion = nn.NLLLoss()
3.8训练模块(可以理解为主函数)
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False,filename=filename):since = time.time() #best_acc = 0"""checkpoint = torch.load(filename)best_acc = checkpoint['best_acc']model.load_state_dict(checkpoint['state_dict'])optimizer.load_state_dict(checkpoint['optimizer'])model.class_to_idx = checkpoint['mapping']"""model.to(device)val_acc_history = []train_acc_history = []train_losses = []valid_losses = []LRs = [optimizer.param_groups[0]['lr']]best_model_wts = copy.deepcopy(model.state_dict())for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 训练和验证for phase in ['train', 'valid']:if phase == 'train':model.train() # 训练else:model.eval() # 验证running_loss = 0.0running_corrects = 0# 把数据都取个遍for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 清零optimizer.zero_grad()# 只有训练的时候计算和更新梯度with torch.set_grad_enabled(phase == 'train'):if is_inception and phase == 'train':outputs, aux_outputs = model(inputs)loss1 = criterion(outputs, labels)loss2 = criterion(aux_outputs, labels)loss = loss1 + 0.4*loss2else:#resnet执行的是这里outputs = model(inputs)loss = criterion(outputs, labels)_, preds = torch.max(outputs, 1)# 训练阶段更新权重if phase == 'train':loss.backward()optimizer.step()# 计算损失running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(dataloaders[phase].dataset)epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)time_elapsed = time.time() - sinceprint('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 得到最好那次的模型if phase == 'valid' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())state = {'state_dict': model.state_dict(),'best_acc': best_acc,'optimizer' : optimizer.state_dict(),}torch.save(state, filename)if phase == 'valid':val_acc_history.append(epoch_acc)valid_losses.append(epoch_loss)scheduler.step(epoch_loss)if phase == 'train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRs.append(optimizer.param_groups[0]['lr'])print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# 训练完后用最好的一次当做模型最终的结果model.load_state_dict(best_model_wts)return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
3.9开始训练
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20, is_inception=(model_name=="inception"))
3.10微调
在2.9中得到的模型,是冻结了卷积层,只训练了全连接层,所以此时希望在此基础上再对卷积层进行训练。
for param in model_ft.parameters():param.requires_grad = True# 再继续训练所有的参数,学习率调小一点
optimizer = optim.Adam(params_to_update, lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)# 损失函数
criterion = nn.NLLLoss()# Load the checkpoint,加载自己的模型checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
#model_ft.class_to_idx = checkpoint['mapping']model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=10, is_inception=(model_name=="inception"))
4.测试模型
4.1加载训练好的模型
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)# GPU模式
model_ft = model_ft.to(device)# 保存文件的名字
filename='seriouscheckpoint.pth'# 加载模型
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
4.2测试数据预处理
1.测试数据处理方法需要跟训练时一直才可以
2.crop操作的目的是保证输入的大小是一致的
3.标准化操作也是必须的,用跟训练数据相同的mean和std,但是需要注意一点训练数据是在0-1上进行标准化,所以测试数据也需要先归一化
4.PyTorch中颜色通道是第一个维度,跟很多工具包都不一样,需要转换
def process_image(image_path):# 读取测试数据img = Image.open(image_path)# Resize,thumbnail方法只能进行缩小,所以进行了判断if img.size[0] > img.size[1]:img.thumbnail((10000, 256))else:img.thumbnail((256, 10000))# Crop操作left_margin = (img.width-224)/2bottom_margin = (img.height-224)/2right_margin = left_margin + 224top_margin = bottom_margin + 224img = img.crop((left_margin, bottom_margin, right_margin, top_margin))# 相同的预处理方法img = np.array(img)/255mean = np.array([0.485, 0.456, 0.406]) #provided meanstd = np.array([0.229, 0.224, 0.225]) #provided stdimg = (img - mean)/std# 注意颜色通道应该放在第一个位置img = img.transpose((2, 0, 1))return img
4.3数据展示
def imshow(image, ax=None, title=None):"""展示数据"""if ax is None:fig, ax = plt.subplots()# 颜色通道还原image = np.array(image).transpose((1, 2, 0))# 预处理还原mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image = std * image + meanimage = np.clip(image, 0, 1)ax.imshow(image)ax.set_title(title)return ax
4.4提取测试数据集
# 得到一个batch的测试数据
dataiter = iter(dataloaders['valid'])
images, labels = dataiter.next()model_ft.eval()if train_on_gpu:output = model_ft(images.cuda())
else:output = model_ft(images)
4.5计算提取数据集的预测结果
_, preds_tensor = torch.max(output, 1)preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
preds
4.6展示预测结果
fig=plt.figure(figsize=(20, 20))
columns =4
rows = 2for idx in range (columns*rows):ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])plt.imshow(im_convert(images[idx]))ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),color=("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
参考文献
1.6-训练结果与模型保存_哔哩哔哩_bilibili
相关文章:
pytorch实战-图像分类(二)(模型训练及验证)(基于迁移学习(理解+代码))
目录 1.迁移学习概念 2.数据预处理 3.训练模型(基于迁移学习) 3.1选择网络,这里用resnet 3.2如果用GPU训练,需要加入以下代码 3.3卷积层冻结模块 3.4加载resnet152模 3.5解释initialize_model函数 3.6迁移学习网络搭建 3.…...
b 树和 b+树的理解
项目场景: 图灵奖获得者(Niklaus Wirth )说过: 程序 数据结构 算法, 也就说我们无时无刻 都在和数据结构打交道。 只是作为 Java 开发,由于技术体系的成熟度较高,使得大部分人认为࿱…...
正则表达式 —— Awk
Awk awk:文本三剑客之一,是功能最强大的文本工具 awk也是按行来进行操作,对行操作完之后,可以根据指定命令来对行取列 awk的分隔符,默认分隔符是空格或tab键,多个空格会压缩成一个 awk的用法 awk的格式…...
国芯新作 | 四核Cortex-A53@1.4GHz,仅168元起?含税?哇!!!
创龙科技SOM-TLT507是一款基于全志科技T507-H处理器设计的4核ARM Cortex-A53全国产工业核心板,主频高达1.416GHz。核心板CPU、ROM、RAM、电源、晶振等所有元器件均采用国产工业级方案,国产化率100%。 核心板通过邮票孔连接方式引出MIPI CSI、HDMI OUT、…...
【MyBatis】 框架原理
目录 10.3【MyBatis】 框架原理 10.3.1 【MyBatis】 整体架构 10.3.2 【MyBatis】 运行原理 10.4 【MyBatis】 核心组件的生命周期 10.4.1 SqlSessionFactoryBuilder 10.4.2 SqlSessionFactory 10.4.3 SqlSession 10.4.4 Mapper Instances 与 Hibernate 框架相比&#…...
三、线性工作流
再生产的各个环节,正确使用gamma编码及gamma解码,使得最终得到的颜色数据与最初输入的物理数据一致。如果使用gamma空间的贴图,在传给着色器前需要从gamma空间转到线性空间。 如果不在线性空间下进行渲染,会产生的问题:…...
2023华数杯数学建模A题思路 - 隔热材料的结构优化控制研究
# 1 赛题 A 题 隔热材料的结构优化控制研究 新型隔热材料 A 具有优良的隔热特性,在航天、军工、石化、建筑、交通等 高科技领域中有着广泛的应用。 目前,由单根隔热材料 A 纤维编织成的织物,其热导率可以直接测出;但是 单根隔热…...
Zabbix分布式监控Web监控
目录 1 概述2 配置 Web 场景2.1 配置步骤2.2 显示 3 Web 场景步骤3.1 创建新的 Web 场景。3.2 定义场景的步骤3.3 保存配置完成的Web 监控场景。 4 Zabbix-Get的使用 1 概述 您可以使用 Zabbix 对多个网站进行可用性方面监控: 要使用 Web 监控,您需要定…...
PHP从入门到精通—PHP开发入门-PHP概述、PHP开发环境搭建、PHP开发环境搭建、第一个PHP程序、PHP开发流程
每开始学习一门语言,都要了解这门语言和进行开发环境的搭建。同样,学生开始PHP学习之前,首先要了解这门语言的历史、语言优势等内容以及了解开发环境的搭建。 PHP概述 认识PHP PHP最初是由Rasmus Lerdorf于1994年为了维护个人网页而编写的一…...
【LeetCode-中等】722. 删除注释
题目链接 722. 删除注释 标签 字符串 步骤 Step1. 先将source合并为一个字符串进行处理,中间补上’\n’,方便后续确定注释开始、结束位置。 string combined; for (auto str : source) {combined str "\n"; }Step2. 定义数组 toDel&am…...
rust里如何判断字符串是否相等呢?
在 Rust 中,有几种方法可以判断字符串是否相等。下面是其中几种常见的方法: 使用 运算符:可以直接使用 运算符比较两个字符串是否相等。例如: fn main() {let str1 "hello";let str2 "world";if str1 …...
python基本知识学习
一、输出语句 在控制台输出Hello,World! print("Hello,World!") 二、注释 单行注释:以#开头 # print("你好") 多行注释: 选中要注释的代码Ctrl/三单引号三双引号 # print("你好") # a1 # a2 print("Hello,World!&…...
vue3和typescript_组件
1 components下新建myComponent.vue 2 页面中引入组件,传入值,并且绑定事件函数。 3...
Qt+联想电脑管家
1.自定义按钮类 效果: (1)仅当未选中,未悬浮时 (2)其他三种情况,均如图 #ifndef BTN_H #define BTN_H#include <QPushButton> class btn : public QPushButton {Q_OBJECT public:btn(QWidget * parent nullptr);void set_normal_icon(…...
论文阅读-BotPercent: Estimating Twitter Bot Populations from Groups to Crowds
目录 摘要 引言 方法 数据集 BotPercent架构 实验结果 活跃用户中的Bot数量 Bot Population among Comment Sections Bot Participation in Content Moderation Votes Bot Population in Different Countries’ Politics 论文链接:https://arxiv.org/pdf/23…...
用于永磁同步电机驱动器的自适应SDRE非线性无传感器速度控制(MatlabSimulink实现)
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...
Spring Cloud+Spring Boot+Mybatis+uniapp+前后端分离实现知识付费平台免费搭建 qt
Java版知识付费源码 Spring CloudSpring BootMybatisuniapp前后端分离实现知识付费平台 提供职业教育、企业培训、知识付费系统搭建服务。系统功能包含:录播课、直播课、题库、营销、公司组织架构、员工入职培训等。 提供私有化部署,免费售…...
删除注释(力扣)
删除注释 题目 给一个 C 程序,删除程序中的注释。这个程序source是一个数组,其中source[i]表示第 i 行源码。 这表示每行源码由 ‘\n’ 分隔。 在 C 中有两种注释风格,行内注释和块注释。 字符串// 表示行注释,表示//和其右侧…...
阿里云AK创建
要在阿里云上创建 Access Key(AK),您需要按照以下步骤进行操作: 登录到阿里云控制台([https://www.aliyun.com/?utm_contentse_1014243503))。 点击右上方的主账号,点击“AccessKey管理”。 …...
OC与Swift的相互调用
OC调用Swift方法 1、在 Build Settings 搜索 Packaging ,设置 Defines Module 为 YES 2、新建 LottieBridge.swift 文件,自动生成桥 ProductName-Bridging-Header.h 3、在 LottieBridge.swift 中,定义Swift类继承于OC类,声明 obj…...
测试微信模版消息推送
进入“开发接口管理”--“公众平台测试账号”,无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息: 关注测试号:扫二维码关注测试号。 发送模版消息: import requests da…...
相机Camera日志实例分析之二:相机Camx【专业模式开启直方图拍照】单帧流程日志详解
【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了: 这一篇我们开始讲: 目录 一、场景操作步骤 二、日志基础关键字分级如下 三、场景日志如下: 一、场景操作步骤 操作步…...
IGP(Interior Gateway Protocol,内部网关协议)
IGP(Interior Gateway Protocol,内部网关协议) 是一种用于在一个自治系统(AS)内部传递路由信息的路由协议,主要用于在一个组织或机构的内部网络中决定数据包的最佳路径。与用于自治系统之间通信的 EGP&…...
LeetCode - 199. 二叉树的右视图
题目 199. 二叉树的右视图 - 力扣(LeetCode) 思路 右视图是指从树的右侧看,对于每一层,只能看到该层最右边的节点。实现思路是: 使用深度优先搜索(DFS)按照"根-右-左"的顺序遍历树记录每个节点的深度对于…...
MySQL JOIN 表过多的优化思路
当 MySQL 查询涉及大量表 JOIN 时,性能会显著下降。以下是优化思路和简易实现方法: 一、核心优化思路 减少 JOIN 数量 数据冗余:添加必要的冗余字段(如订单表直接存储用户名)合并表:将频繁关联的小表合并成…...
OCR MLLM Evaluation
为什么需要评测体系?——背景与矛盾 能干的事: 看清楚发票、身份证上的字(准确率>90%),速度飞快(眨眼间完成)。干不了的事: 碰到复杂表格(合并单元…...
前端调试HTTP状态码
1xx(信息类状态码) 这类状态码表示临时响应,需要客户端继续处理请求。 100 Continue 服务器已收到请求的初始部分,客户端应继续发送剩余部分。 2xx(成功类状态码) 表示请求已成功被服务器接收、理解并处…...
【免费数据】2005-2019年我国272个地级市的旅游竞争力多指标数据(33个指标)
旅游业是一个城市的重要产业构成。旅游竞争力是一个城市竞争力的重要构成部分。一个城市的旅游竞争力反映了其在旅游市场竞争中的比较优势。 今日我们分享的是2005-2019年我国272个地级市的旅游竞争力多指标数据!该数据集源自2025年4月发表于《地理学报》的论文成果…...
文件上传漏洞防御全攻略
要全面防范文件上传漏洞,需构建多层防御体系,结合技术验证、存储隔离与权限控制: 🔒 一、基础防护层 前端校验(仅辅助) 通过JavaScript限制文件后缀名(白名单)和大小,提…...
汇编语言学习(三)——DoxBox中debug的使用
目录 一、安装DoxBox,并下载汇编工具(MASM文件) 二、debug是什么 三、debug中的命令 一、安装DoxBox,并下载汇编工具(MASM文件) 链接: https://pan.baidu.com/s/1IbyJj-JIkl_oMOJmkKiaGQ?pw…...
