图像识别模型与训练策略
图像预处理
1.需要将图像Resize到相同大小输入到卷积网络中
2.翻转、裁剪、色彩偏移等操作
3.转化为Tensor数据格式
4.对RGB三种颜色通道进行标准化
data_transforms = {'train': transforms.Compose([transforms.Resize([96, 96]),transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选transforms.CenterCrop(64),#从中心开始裁剪transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差]),'valid': transforms.Compose([transforms.Resize([64, 64]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}
读取数据
将训练集中各个类别文件夹中的数据经过Transforms增强后进行统一读取封装
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
batch_size = 128image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
迁移学习
使用官方发布的模型和参数,将参数冻住不更新
def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = Falsemodel_ft = models.resnet18()#18层的能快点,条件好点的也可以选152
model_ft
修改输出层
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):model_ft = models.resnet18(pretrained=use_pretrained)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, 102)#类别数自己根据自己任务来input_size = 64#输入大小根据自己配置来return model_ft, input_size
更新输出层参数
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)#GPU还是CPU计算
model_ft = model_ft.to(device)# 模型保存,名字自己起
filename='checkpoint.pth'# 是否训练所有层
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:params_to_update = []for name,param in model_ft.named_parameters():if param.requires_grad == True:params_to_update.append(param)print("\t",name)
else:for name,param in model_ft.named_parameters():if param.requires_grad == True:print("\t",name)
优化器设置
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)#要训练啥参数,你来定
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)#学习率每7个epoch衰减成原来的1/10
criterion = nn.CrossEntropyLoss()
训练策略
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,filename='best.pt'):#咱们要算时间的since = time.time()#也要记录最好的那一次best_acc = 0#模型也得放到你的CPU或者GPUmodel.to(device)#训练过程中打印一堆损失和指标val_acc_history = []train_acc_history = []train_losses = []valid_losses = []#学习率LRs = [optimizer.param_groups[0]['lr']]#最好的那次模型,后续会变的,先初始化best_model_wts = copy.deepcopy(model.state_dict())#一个个epoch来遍历for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 训练和验证for phase in ['train', 'valid']:if phase == 'train':model.train() # 训练else:model.eval() # 验证running_loss = 0.0running_corrects = 0# 把数据都取个遍for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)#放到你的CPU或GPUlabels = labels.to(device)# 清零optimizer.zero_grad()# 只有训练的时候计算和更新梯度outputs = model(inputs)loss = criterion(outputs, labels)_, preds = torch.max(outputs, 1)# 训练阶段更新权重if phase == 'train':loss.backward()optimizer.step()# 计算损失running_loss += loss.item() * inputs.size(0)#0表示batch那个维度running_corrects += torch.sum(preds == labels.data)#预测结果最大的和真实值是否一致epoch_loss = running_loss / len(dataloaders[phase].dataset)#算平均epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)time_elapsed = time.time() - since#一个epoch我浪费了多少时间print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 得到最好那次的模型if phase == 'valid' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())state = {'state_dict': model.state_dict(),#字典里key就是各层的名字,值就是训练好的权重'best_acc': best_acc,'optimizer' : optimizer.state_dict(),}torch.save(state, filename)if phase == 'valid':val_acc_history.append(epoch_acc)valid_losses.append(epoch_loss)#scheduler.step(epoch_loss)#学习率衰减if phase == 'train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRs.append(optimizer.param_groups[0]['lr'])print()scheduler.step()#学习率衰减time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# 训练完后用最好的一次当做模型最终的结果,等着一会测试model.load_state_dict(best_model_wts)return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
相关文章:
图像识别模型与训练策略
图像预处理 1.需要将图像Resize到相同大小输入到卷积网络中 2.翻转、裁剪、色彩偏移等操作 3.转化为Tensor数据格式 4.对RGB三种颜色通道进行标准化 data_transforms {train: transforms.Compose([transforms.Resize([96, 96]),transforms.RandomRotation(45),#随机旋转&…...
算法工程师-机器学习面试题总结(3)
FM模型 FM模型与逻辑回归相比有什么优缺点? FM(因子分解机)模型和逻辑回归是两种常见的预测建模方法,它们在一些方面有不同的优缺点。 FM模型的优点: 1. 能够捕获特征之间的交互作用:FM模型通过对特征向量…...
ROS2学习(五)进程内topic高效通信
对ROS2有一定了解后,我们会发现ROS2中节点和ROS1中节点的概率有很大的区别。在ROS1中节点是最小的进程单元。在ROS2中节点与进程和线程的概念完全区分开了。具体区别可以参考 ROS2学习(四)进程,线程与节点的关系。 在ROS2中同一个进程中可能存在多个节点…...
算法-最大数
给定一组非负整数 nums,重新排列每个数的顺序(每个数不可拆分)使之组成一个最大的整数。 注意:输出结果可能非常大,所以你需要返回一个字符串而不是整数。 输入:nums [10,2] 输出:"210&…...
Spark中使用RDD算子GroupBy做词频统计的方法
测试文件及环境 测试文件在本地D://tmp/spark.txt,Spark采用Local模式运行,Spark版本3.2.0,Scala版本2.12,集成idea开发环境。 hello world java world java java实验代码 import org.apache.spark.rdd.RDD import org.apache.…...
如何使用Kafka构建事件驱动的架构
事件驱动的架构(EDA)是一种软件设计模式,它关注事件的生成、检测和使用,以支持高效和可扩展的系统。在EDA中,事件是组件之间通信的主要手段,允许它们实时交互和响应更改。这种架构促进了松散耦合、可扩展性和响应性,使…...
ES6 解构赋值
解构赋值 解构赋值是一种在编程中常见且方便的语法特性,它可以让你从数组或对象中快速提取数据,并将数据赋值给变量。在许多编程语言中都有类似的特性。 在 JavaScript 中,解构赋值使得从数组或对象中提取数据变得简单。它可以用于数组和对…...
HTML5注册页面
分析 注册界面实际上是一个表格(对齐),一行有两个单元格。 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevic…...
python中的JSON模块详解
简介 JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式,它使得人们很容易的进行阅读和编写 同时也方便了机器进行解析和生成。适用于进行数据交互的场景,比如网站前台与后台之间的数据交互 网址 官方文档 json — JSON encoder and dec…...
Syncfusion Essential Edit for WPF Crack
Syncfusion Essential Edit for WPF Crack 在任何WPF应用程序中启用语法高亮显示。 Syncfusion Essential Edit for WPF是一款具有所有基本功能的编辑器,如文本编辑、剪切、复制和粘贴。它允许用户从各种文件格式打开文件并将其保存为各种文件格式。Syncfusion Esse…...
机器学习深度学习——卷积神经网络(LeNet)
👨🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——池化层 📚订阅专栏:机器学习&&深度学习 希望文章对你们有所帮助 卷积神…...
Pytorch Tutorial【Chapter 2. Autograd】
Pytorch Tutorial 文章目录 Pytorch TutorialChapter 2. Autograd1. Review Matrix Calculus1.1 Definition向量对向量求导1.2 Definition标量对向量求导1.3 Definition标量对矩阵求导 2.关于autograd的说明3. grad的计算3.1 Manual手动计算3.2 backward()自动计算 Reference C…...
Python第三方库国内镜像下载地址
Python第三方库国内镜像下载地址 一、清华大学二、中国科技大学三、安装方法 一、清华大学 https://pypi.tuna.tsinghua.edu.cn/simple 二、中国科技大学 https://pypi.mirrors.ustc.edu.cn/simple 三、安装方法 例如 pyhook3 插件的安装方法,执行下面命令安装…...
从浏览器输入url到页面加载(七)服务端机器一般部署在哪里
前言 上一节,我们说到了CDN和路由器的关系,说到了公有地址,说到了通信线路服务,这一节跳过那些看不懂的深层知识,直接开始说web服务器。 1. 服务端机器为什么不部署在公司内部 记得在之前的一段时间里,公…...
Pytorch深度学习-----神经网络之Sequential的详细使用及实战详解
系列文章目录 PyTorch深度学习——Anaconda和PyTorch安装 Pytorch深度学习-----数据模块Dataset类 Pytorch深度学习------TensorBoard的使用 Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Co…...
安全基础 --- https详解 + 数组(js)
CIA三属性:完整性(Confidentiality)、保密性(Integrity)、可用性(Availability),也称信息安全三要素。 https 核心技术:用非对称加密传输对称加密的密钥,然后…...
vue加载大量数据优化
在Vue中加载大量数据并形成列表时,可以通过以下方法来优化性能: 分页加载:不要一次性加载所有的数据,而是分批加载数据,每次只加载当前页需要显示的数据量。可以使用第三方库如vue-infinite-loading来实现无限滚动加载…...
WebRTC 之音视频同步
在网络视频会议中, 我们常会遇到音视频不同步的问题, 我们有一个专有名词 lip-sync 唇同步来描述这类问题,当我们看到人的嘴唇动作与听到的声音对不上的时候,不同步的问题就出现了 而在线会议中, 听见清晰的声音是优先…...
kubernetes基于helm部署gitlab-runner
kubernetes基于helm部署gitlab-runner 这篇博文介绍如何在 Kubernetes 中使用helm部署 GitLab-runner。 先决条件: 已运行的 Kubernetes 集群已运行的 gitlab 实例 项目地址:https://gitlab.com/gitlab-org/charts/gitlab-runner 官方文档ÿ…...
深度学习和OpenCV的对象检测(MobileNet SSD图像识别)
基于深度学习的对象检测时,我们主要分享以下三种主要的对象检测方法: Faster R-CNN(后期会来学习分享)你只看一次(YOLO,最新版本YOLO3,后期我们会分享)单发探测器(SSD,本节介绍,若你的电脑配置比较低,此方法比较适合R-CNN是使用深度学习进行物体检测的训练模型; 然而,…...
深度学习在微纳光子学中的应用
深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向: 逆向设计 通过神经网络快速预测微纳结构的光学响应,替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...
MPNet:旋转机械轻量化故障诊断模型详解python代码复现
目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...
CVPR 2025 MIMO: 支持视觉指代和像素grounding 的医学视觉语言模型
CVPR 2025 | MIMO:支持视觉指代和像素对齐的医学视觉语言模型 论文信息 标题:MIMO: A medical vision language model with visual referring multimodal input and pixel grounding multimodal output作者:Yanyuan Chen, Dexuan Xu, Yu Hu…...
Prompt Tuning、P-Tuning、Prefix Tuning的区别
一、Prompt Tuning、P-Tuning、Prefix Tuning的区别 1. Prompt Tuning(提示调优) 核心思想:固定预训练模型参数,仅学习额外的连续提示向量(通常是嵌入层的一部分)。实现方式:在输入文本前添加可训练的连续向量(软提示),模型只更新这些提示参数。优势:参数量少(仅提…...
Objective-C常用命名规范总结
【OC】常用命名规范总结 文章目录 【OC】常用命名规范总结1.类名(Class Name)2.协议名(Protocol Name)3.方法名(Method Name)4.属性名(Property Name)5.局部变量/实例变量(Local / Instance Variables&…...
【机器视觉】单目测距——运动结构恢复
ps:图是随便找的,为了凑个封面 前言 在前面对光流法进行进一步改进,希望将2D光流推广至3D场景流时,发现2D转3D过程中存在尺度歧义问题,需要补全摄像头拍摄图像中缺失的深度信息,否则解空间不收敛…...
全球首个30米分辨率湿地数据集(2000—2022)
数据简介 今天我们分享的数据是全球30米分辨率湿地数据集,包含8种湿地亚类,该数据以0.5X0.5的瓦片存储,我们整理了所有属于中国的瓦片名称与其对应省份,方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...
页面渲染流程与性能优化
页面渲染流程与性能优化详解(完整版) 一、现代浏览器渲染流程(详细说明) 1. 构建DOM树 浏览器接收到HTML文档后,会逐步解析并构建DOM(Document Object Model)树。具体过程如下: (…...
【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)
升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点,但无自动故障转移能力,Master宕机后需人工切换,期间消息可能无法读取。Slave仅存储数据,无法主动升级为Master响应请求ÿ…...
k8s业务程序联调工具-KtConnect
概述 原理 工具作用是建立了一个从本地到集群的单向VPN,根据VPN原理,打通两个内网必然需要借助一个公共中继节点,ktconnect工具巧妙的利用k8s原生的portforward能力,简化了建立连接的过程,apiserver间接起到了中继节…...
