基于迁移学习的手势分类模型训练
1、基本原理介绍
这里介绍的单指模型迁移。一般我们训练模型时,往往会自定义一个模型类,这个类中定义了神经网络的结构,训练时将数据集输入,从0开始训练;而迁移学习中(单指模型迁移策略),是在一个已经有过训练基础的模型上,用自己的数据集,进一步训练,使得这个模型能够完成我们需要的任务。
这么做有有这样几个显而易见的好处:
※ 因为模型之前被训练过,所以初始参数不会是0,这样能够加速模型训练
※ 因为预训练模型(什么是预训练模型下文会讲到)在其他数据集上训练过,而其他数据集往往和我们用的数据集存在一定的区别,所以这可以提高模型的泛化能力
※ 通过迁移学习,可以将来自大规模数据的优势转移到小规模或新任务上,提高模型的表现和效果
2、预训练模型
在进行迁移学习时,我们要先找到一个预训练模型。在分类任务领域,比较流行的如resnet系列、mobilenet系列(更轻量化)、vgg(系列)、efficientnet(系列)等等网络,都是比较常用且容易获得的预训练模型,这些模型都能够通过python直接下载。
而且由于上述模型基本都是在ImageNet这一大规模,多分类类别的数据集上进行过训练的,所以对于简单的二分类等少数类别分类,能有较好的效果。
3、训练流程
迁移学习完整的训练流程和一般搭建神经网络的训练模型的流程基本类似:数据预处理->数据集的切分->加载预训练模型(搭建神经网络)->设置超参数/损失函数/优化器等->训练模型
3.1 模型训练
下面的代码是一个利用mobilenet网络训练得到的手势分类模型,该模型能够较准确的分类不同类别手势。
相关解释已在代码中注释说明。
from torchvision.models import mobilenet_v2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation# 定义数据预处理和增强器
transform = Compose([RandomHorizontalFlip(), # 随机水平翻转RandomRotation(10), # 随机旋转10度Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载数据集并应用预处理和增强器
dataset = ImageFolder(root='data', transform=transform)
# 这里由于数据比较少,将所有数据集全部用来训练,得到的模型直接拿来用了,这其实不算是非常规范的操作,仅供参考# 定义网络结构
model = mobilenet_v2(pretrained=True) # 加载预训练模型,也可以试试其他模型,效果差别挺大的
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 5) # 假设是5分类问题,具体几分类,改这里的参数就行了# 将模型移动到设备上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)# 定义优化器和损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()# 定义训练循环
def train_model(model, criterion, optimizer, num_epochs, train_loader):for epoch in range(num_epochs):model.train() # 设置模型为训练模式train_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item() * inputs.size(0)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()epoch_loss = train_loss / totalepoch_acc = 100. * correct / totalprint(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')# 创建训练集的DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)# 开始训练模型
train_model(model, criterion, optimizer, num_epochs=15, train_loader=train_loader)
torch.save(model, 'my_model(1).pth')
3.2 数据集文件结构
当然,你也可以自己定义读取数据集的data_loader类。
3.3 模型推理
这段代码是用训练得到的模型对一张图片进行推理测试的,如果需要对系列图片进行推理,评估模型效果,可自行修改,调用对应函数即可。
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
def predict_image(image_path, model_path='my_model(1).pth'):image = Image.open(image_path).convert("RGB")# 对测试的图片进行预处理,需要和训练时处理的方式一样transform = Compose([Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])image_tensor = transform(image).unsqueeze(0)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')image_tensor = image_tensor.to(device)model = torch.load(model_path,map_location=device)model.eval()with torch.no_grad():output = model(image_tensor)_, predicted = torch.max(output.data, 1) # 获得分类标记return predicted.item()
if __name__=="__main__":image_path = "test2/6.jpg"print(predict_image(image_path))
3.4 整体项目文件
4、补充说明
这种利用迁移学习策略,进行少类别,不同类别特征差距小的任务需求来说,效果一般来说是比较好的。因为之前做过相关实验,准确率90%以上是很容易的,所以这里没有模型评估,生成混淆矩阵等过程。对于多类别分类,建议有完整的评估体系。
上述使用的方法仅适用于分类任务,对于真正的目标检测如手势识别,直接使用该模型的问题是:由于无法定位手势的位置,所以导致识别不准确。
本实验数据集是不同类别手势图片,为自制,不开源。
相关文章:

基于迁移学习的手势分类模型训练
1、基本原理介绍 这里介绍的单指模型迁移。一般我们训练模型时,往往会自定义一个模型类,这个类中定义了神经网络的结构,训练时将数据集输入,从0开始训练;而迁移学习中(单指模型迁移策略)&#x…...

个性化音频生成GPT-SoVits部署使用和API调用
一、训练自己的音色模型步骤 1、准备好要训练的数据,放在Data文件夹中,按照文件模板中的结构进行存放数据 2、双击打开go-webui.bat文件,等待页面跳转 3、页面打开后,开始训练自己的模型 (1)、人声伴奏分…...

MFC列表框示例
本文仅供学习交流,严禁用于商业用途,如本文涉及侵权请及时联系本人将于及时删除 目录 1.示例内容 2.程序步骤 3.运行结果 4.代码全文 1.示例内容 编写一个对话框应用程序CMFC_Li6_4_学生信息Dlg,对话框中有一个列表框,当用户…...

Android TabLayout的简单用法
TabLayout 注意这里添加tab,使用binding.tabLayout.newTab()进行创建 private fun initTabs() {val tab binding.tabLayout.newTab()tab.text "模板库"binding.tabLayout.addTab(tab)binding.tabLayout.addOnTabSelectedListener(object : TabLayout.On…...
基于vite + pnpm monorepo 实现一个UI组件库
基于vite pnpm monorepo的vue组件库 仓库地址 思路 好多文章都是直接咔咔咔的上代码。跟着做也没问题,但总觉得少了些什么。下次做的时候还要找文章参考。。 需求有三个模块,那么就需要三个包。使用monorepo进行分包管理。 a. 组件库 b. 组件库文档…...

FDM3D打印系列——Luck13关节可动模型打印和各种材料的尝试
luck13可动关节模型FDM3D打印制作过程 大家好,我是阿赵。 最近我沉迷于打印一个叫做Luck13的关节超可动人偶。 首先说明一下,这个模型是分为了外甲和骨骼两个部分的。 为什么我会打印了这么多个呢? 一、第一次尝试——PLATPU 刚开始…...
windows10 获取磁盘类型
powershell Get-PhysicalDisk | Select FriendlyName, MediaType FriendlyName MediaType ------------ --------- NVMe PC SN740 NVMe WD 256GB SSD WDC WD10EZEX-75WN4A1 HDD 适用场景 SSD: 适合需要快速访问速度和较高响…...
数据库之运算符
目录 一、算数运算符 二、比较运算符 1.常用比较运算符 2.实现特殊功能的比较运算符 三、逻辑运算符 1.逻辑与运算符(&&或者AND) 2.逻辑或运算符(||或者OR) 3.逻辑非运算符(!或者NOT&#…...
【自动化机器学习AutoML】AutoML工具和平台的使用
自动化机器学习AutoML:AutoML工具和平台的使用 目录 引言什么是AutoMLAutoML的优势常见的AutoML工具和平台 Google Cloud AutoMLH2O.aiAuto-sklearnTPOTMLBox AutoML的基本使用 Google Cloud AutoML使用示例Auto-sklearn使用示例 AutoML的应用场景结论 引言 自动…...

【每日一练】python求最后一个单词的长度
""" 求某变量中最后一个单词的长度 例如s"Good morning, champ! Youre going to rock this day" 分析思路: 遇到字符串问题,经常和列表结合使用来解决, 可以先用列表的.split()分割方法进行单词分割, 再…...

[红明谷CTF 2021]write_shell 1
目录 代码审计check()$_GET["action"] ?? "" 解题 代码审计 <?php error_reporting(0); highlight_file(__FILE__); function check($input){if(preg_match("/| |_|php|;|~|\\^|\\|eval|{|}/i",$input)){// if(preg_match("/| |_||p…...
【Go - sync.once】
sync.Once 是 Go 语言标准库中的一个结构体,它的作用是确保某个操作在全局范围内只被执行一次。这对于实现单例模式或需要一次性初始化资源的场景非常有用。 典型用法 sync.Once 提供了一个方法 Do(f func()),该方法接收一个没有参数和返回值的函数 f …...

Spark RPC框架详解
文章目录 前言Spark RPC模型概述RpcEndpointRpcEndpointRefRpcEnv 基于Netty的RPC实现NettyRpcEndpointRefNettyRpcEnv消息的发送消息的接收RpcEndpointRef的构造方式直接通过RpcEndpoint构造RpcEndpointRef通过消息发送RpcEndpointRef Endpoint的注册Dispatcher消息的投递消息…...

win10安装ElasticSearch7.x和分词插件
说明: 以下内容整理自网络,格式调整优化,更易阅读,希望能对需要的人有所帮助。 一 安装 Java环境 ElasticSearch使用Java开发的,依赖Java环境,安装 ElasticSearch 7.x 之前,需要先安装jdk-8。…...
Linux中,MySQL的用户管理
MySQL库中的表及其作用 user表 User表是MySQL中最重要的一个权限表,记录允许连接到服务器的帐号信息,里面的权限是全局级的。 db表和host表 db表和host表是MySQL数据中非常重要的权限表。db表中存储了用户对某个数据库的操作权限,决定用户…...
个人电脑网络安全 之 防浏览器和端口溢出攻击 和 权限对系统的重要性
防浏览器和端口溢出攻击 该如何防 很多人都不明白 我相信很多人只知道杀毒软件 却不知道网络防火墙 防火墙分两种 : 1、 病毒防火墙 也就是我们说的杀毒软件 2、 网络防火墙 这是用来防软件恶意通信的 使用防火墙 有两种 1、 半开式规则…...
美食聚焦 -- 仿大众点评项目技术难点总结
1 实现点赞功能显示哪些用户点赞过并安装时间顺序排序 使用sort_set 进行存储,把博客id作为key,用户id作为value,时间戳作为score 但存储成功之后还是没有成功按照时间顺序排名,因为sql语句,比如最后in(5…...

拓扑图:揭示复杂系统背后的结构与逻辑
在现代软件开发和运维中,图形化的表示方式越来越重要。拓扑图,作为一种关键的可视化工具,不仅能够帮助我们理解系统的结构和组件间的关系,还能提升系统的可维护性和可扩展性。 什么是拓扑图? 拓扑图是一种展示系统或网络中各个节点(如服务器、交换机、数据库等)及其连…...

Java面试八股之什么是spring boot starter
什么是spring boot starter Spring Boot Starter是Spring Boot项目中的一个重要概念。它是一种依赖管理机制,用于简化Maven或Gradle配置文件中的依赖项声明。Spring Boot Starter提供了一组预定义的依赖关系,这些依赖关系被封装在一个单一的包中&#x…...
探究项目未能获得ASPICE 1、2级能力的原因及改进策略
项目整体未能获得ASPICE 1、2级能力的原因可能涉及多个方面,以下是基于参考文章中的信息和可能的情境进行的分析: 1.过程成熟度不足:ASPICE(Automotive Software Process Improvement and Capability Determination)是…...

eNSP-Cloud(实现本地电脑与eNSP内设备之间通信)
说明: 想象一下,你正在用eNSP搭建一个虚拟的网络世界,里面有虚拟的路由器、交换机、电脑(PC)等等。这些设备都在你的电脑里面“运行”,它们之间可以互相通信,就像一个封闭的小王国。 但是&#…...

【人工智能】神经网络的优化器optimizer(二):Adagrad自适应学习率优化器
一.自适应梯度算法Adagrad概述 Adagrad(Adaptive Gradient Algorithm)是一种自适应学习率的优化算法,由Duchi等人在2011年提出。其核心思想是针对不同参数自动调整学习率,适合处理稀疏数据和不同参数梯度差异较大的场景。Adagrad通…...

如何在看板中体现优先级变化
在看板中有效体现优先级变化的关键措施包括:采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中,设置任务排序规则尤其重要,因为它让看板视觉上直观地体…...
渲染学进阶内容——模型
最近在写模组的时候发现渲染器里面离不开模型的定义,在渲染的第二篇文章中简单的讲解了一下关于模型部分的内容,其实不管是方块还是方块实体,都离不开模型的内容 🧱 一、CubeListBuilder 功能解析 CubeListBuilder 是 Minecraft Java 版模型系统的核心构建器,用于动态创…...

React19源码系列之 事件插件系统
事件类别 事件类型 定义 文档 Event Event 接口表示在 EventTarget 上出现的事件。 Event - Web API | MDN UIEvent UIEvent 接口表示简单的用户界面事件。 UIEvent - Web API | MDN KeyboardEvent KeyboardEvent 对象描述了用户与键盘的交互。 KeyboardEvent - Web…...
Rust 异步编程
Rust 异步编程 引言 Rust 是一种系统编程语言,以其高性能、安全性以及零成本抽象而著称。在多核处理器成为主流的今天,异步编程成为了一种提高应用性能、优化资源利用的有效手段。本文将深入探讨 Rust 异步编程的核心概念、常用库以及最佳实践。 异步编程基础 什么是异步…...

《基于Apache Flink的流处理》笔记
思维导图 1-3 章 4-7章 8-11 章 参考资料 源码: https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...

NFT模式:数字资产确权与链游经济系统构建
NFT模式:数字资产确权与链游经济系统构建 ——从技术架构到可持续生态的范式革命 一、确权技术革新:构建可信数字资产基石 1. 区块链底层架构的进化 跨链互操作协议:基于LayerZero协议实现以太坊、Solana等公链资产互通,通过零知…...
数据库分批入库
今天在工作中,遇到一个问题,就是分批查询的时候,由于批次过大导致出现了一些问题,一下是问题描述和解决方案: 示例: // 假设已有数据列表 dataList 和 PreparedStatement pstmt int batchSize 1000; // …...

mac 安装homebrew (nvm 及git)
mac 安装nvm 及git 万恶之源 mac 安装这些东西离不开Xcode。及homebrew 一、先说安装git步骤 通用: 方法一:使用 Homebrew 安装 Git(推荐) 步骤如下:打开终端(Terminal.app) 1.安装 Homebrew…...