当前位置: 首页 > news >正文

基于迁移学习的手势分类模型训练

1、基本原理介绍

       这里介绍的单指模型迁移。一般我们训练模型时,往往会自定义一个模型类,这个类中定义了神经网络的结构,训练时将数据集输入,从0开始训练;而迁移学习中(单指模型迁移策略),是在一个已经有过训练基础的模型上,用自己的数据集,进一步训练,使得这个模型能够完成我们需要的任务。

这么做有有这样几个显而易见的好处:

※  因为模型之前被训练过,所以初始参数不会是0,这样能够加速模型训练

※  因为预训练模型(什么是预训练模型下文会讲到)在其他数据集上训练过,而其他数据集往往和我们用的数据集存在一定的区别,所以这可以提高模型的泛化能力

※  通过迁移学习,可以将来自大规模数据的优势转移到小规模或新任务上,提高模型的表现和效果

2、预训练模型

        在进行迁移学习时,我们要先找到一个预训练模型。在分类任务领域,比较流行的如resnet系列、mobilenet系列(更轻量化)、vgg(系列)、efficientnet(系列)等等网络,都是比较常用且容易获得的预训练模型,这些模型都能够通过python直接下载。

        而且由于上述模型基本都是在ImageNet这一大规模,多分类类别的数据集上进行过训练的,所以对于简单的二分类等少数类别分类,能有较好的效果。

3、训练流程

迁移学习完整的训练流程和一般搭建神经网络的训练模型的流程基本类似:数据预处理->数据集的切分->加载预训练模型(搭建神经网络)->设置超参数/损失函数/优化器等->训练模型

3.1 模型训练

下面的代码是一个利用mobilenet网络训练得到的手势分类模型,该模型能够较准确的分类不同类别手势。

相关解释已在代码中注释说明。

from torchvision.models import mobilenet_v2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation# 定义数据预处理和增强器
transform = Compose([RandomHorizontalFlip(),  # 随机水平翻转RandomRotation(10),      # 随机旋转10度Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载数据集并应用预处理和增强器
dataset = ImageFolder(root='data', transform=transform)
# 这里由于数据比较少,将所有数据集全部用来训练,得到的模型直接拿来用了,这其实不算是非常规范的操作,仅供参考# 定义网络结构
model = mobilenet_v2(pretrained=True)  # 加载预训练模型,也可以试试其他模型,效果差别挺大的
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 5)  # 假设是5分类问题,具体几分类,改这里的参数就行了# 将模型移动到设备上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)# 定义优化器和损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()# 定义训练循环
def train_model(model, criterion, optimizer, num_epochs, train_loader):for epoch in range(num_epochs):model.train()  # 设置模型为训练模式train_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item() * inputs.size(0)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()epoch_loss = train_loss / totalepoch_acc = 100. * correct / totalprint(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')# 创建训练集的DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)# 开始训练模型
train_model(model, criterion, optimizer, num_epochs=15, train_loader=train_loader)
torch.save(model, 'my_model(1).pth')

3.2 数据集文件结构

当然,你也可以自己定义读取数据集的data_loader类。

3.3 模型推理

这段代码是用训练得到的模型对一张图片进行推理测试的,如果需要对系列图片进行推理,评估模型效果,可自行修改,调用对应函数即可。

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
def predict_image(image_path, model_path='my_model(1).pth'):image = Image.open(image_path).convert("RGB")# 对测试的图片进行预处理,需要和训练时处理的方式一样transform = Compose([Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])image_tensor = transform(image).unsqueeze(0)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')image_tensor = image_tensor.to(device)model = torch.load(model_path,map_location=device)model.eval()with torch.no_grad():output = model(image_tensor)_, predicted = torch.max(output.data, 1)  # 获得分类标记return predicted.item()
if __name__=="__main__":image_path = "test2/6.jpg"print(predict_image(image_path))

3.4 整体项目文件

4、补充说明

        这种利用迁移学习策略,进行少类别,不同类别特征差距小的任务需求来说,效果一般来说是比较好的。因为之前做过相关实验,准确率90%以上是很容易的,所以这里没有模型评估,生成混淆矩阵等过程。对于多类别分类,建议有完整的评估体系。

        上述使用的方法仅适用于分类任务,对于真正的目标检测如手势识别,直接使用该模型的问题是:由于无法定位手势的位置,所以导致识别不准确。

        本实验数据集是不同类别手势图片,为自制,不开源。

相关文章:

基于迁移学习的手势分类模型训练

1、基本原理介绍 这里介绍的单指模型迁移。一般我们训练模型时,往往会自定义一个模型类,这个类中定义了神经网络的结构,训练时将数据集输入,从0开始训练;而迁移学习中(单指模型迁移策略)&#x…...

个性化音频生成GPT-SoVits部署使用和API调用

一、训练自己的音色模型步骤 1、准备好要训练的数据,放在Data文件夹中,按照文件模板中的结构进行存放数据 2、双击打开go-webui.bat文件,等待页面跳转 3、页面打开后,开始训练自己的模型 (1)、人声伴奏分…...

MFC列表框示例

本文仅供学习交流,严禁用于商业用途,如本文涉及侵权请及时联系本人将于及时删除 目录 1.示例内容 2.程序步骤 3.运行结果 4.代码全文 1.示例内容 编写一个对话框应用程序CMFC_Li6_4_学生信息Dlg,对话框中有一个列表框,当用户…...

Android TabLayout的简单用法

TabLayout 注意这里添加tab,使用binding.tabLayout.newTab()进行创建 private fun initTabs() {val tab binding.tabLayout.newTab()tab.text "模板库"binding.tabLayout.addTab(tab)binding.tabLayout.addOnTabSelectedListener(object : TabLayout.On…...

基于vite + pnpm monorepo 实现一个UI组件库

基于vite pnpm monorepo的vue组件库 仓库地址 思路 好多文章都是直接咔咔咔的上代码。跟着做也没问题,但总觉得少了些什么。下次做的时候还要找文章参考。。 需求有三个模块,那么就需要三个包。使用monorepo进行分包管理。 a. 组件库 b. 组件库文档…...

FDM3D打印系列——Luck13关节可动模型打印和各种材料的尝试

luck13可动关节模型FDM3D打印制作过程 大家好,我是阿赵。   最近我沉迷于打印一个叫做Luck13的关节超可动人偶。 首先说明一下,这个模型是分为了外甲和骨骼两个部分的。   为什么我会打印了这么多个呢? 一、第一次尝试——PLATPU 刚开始…...

windows10 获取磁盘类型

powershell Get-PhysicalDisk | Select FriendlyName, MediaType FriendlyName MediaType ------------ --------- NVMe PC SN740 NVMe WD 256GB SSD WDC WD10EZEX-75WN4A1 HDD 适用场景 SSD: 适合需要快速访问速度和较高响…...

数据库之运算符

目录 一、算数运算符 二、比较运算符 1.常用比较运算符 2.实现特殊功能的比较运算符 三、逻辑运算符 1.逻辑与运算符(&&或者AND) 2.逻辑或运算符(||或者OR) 3.逻辑非运算符(!或者NOT&#…...

【自动化机器学习AutoML】AutoML工具和平台的使用

自动化机器学习AutoML:AutoML工具和平台的使用 目录 引言什么是AutoMLAutoML的优势常见的AutoML工具和平台 Google Cloud AutoMLH2O.aiAuto-sklearnTPOTMLBox AutoML的基本使用 Google Cloud AutoML使用示例Auto-sklearn使用示例 AutoML的应用场景结论 引言 自动…...

【每日一练】python求最后一个单词的长度

""" 求某变量中最后一个单词的长度 例如s"Good morning, champ! Youre going to rock this day" 分析思路: 遇到字符串问题,经常和列表结合使用来解决, 可以先用列表的.split()分割方法进行单词分割, 再…...

[红明谷CTF 2021]write_shell 1

目录 代码审计check()$_GET["action"] ?? "" 解题 代码审计 <?php error_reporting(0); highlight_file(__FILE__); function check($input){if(preg_match("/| |_|php|;|~|\\^|\\|eval|{|}/i",$input)){// if(preg_match("/| |_||p…...

【Go - sync.once】

sync.Once 是 Go 语言标准库中的一个结构体&#xff0c;它的作用是确保某个操作在全局范围内只被执行一次。这对于实现单例模式或需要一次性初始化资源的场景非常有用。 典型用法 sync.Once 提供了一个方法 Do(f func())&#xff0c;该方法接收一个没有参数和返回值的函数 f …...

Spark RPC框架详解

文章目录 前言Spark RPC模型概述RpcEndpointRpcEndpointRefRpcEnv 基于Netty的RPC实现NettyRpcEndpointRefNettyRpcEnv消息的发送消息的接收RpcEndpointRef的构造方式直接通过RpcEndpoint构造RpcEndpointRef通过消息发送RpcEndpointRef Endpoint的注册Dispatcher消息的投递消息…...

win10安装ElasticSearch7.x和分词插件

说明&#xff1a; 以下内容整理自网络&#xff0c;格式调整优化&#xff0c;更易阅读&#xff0c;希望能对需要的人有所帮助。 一 安装 Java环境 ElasticSearch使用Java开发的&#xff0c;依赖Java环境&#xff0c;安装 ElasticSearch 7.x 之前&#xff0c;需要先安装jdk-8。…...

Linux中,MySQL的用户管理

MySQL库中的表及其作用 user表 User表是MySQL中最重要的一个权限表&#xff0c;记录允许连接到服务器的帐号信息&#xff0c;里面的权限是全局级的。 db表和host表 db表和host表是MySQL数据中非常重要的权限表。db表中存储了用户对某个数据库的操作权限&#xff0c;决定用户…...

个人电脑网络安全 之 防浏览器和端口溢出攻击 和 权限对系统的重要性

防浏览器和端口溢出攻击 该如何防 很多人都不明白 我相信很多人只知道杀毒软件 却不知道网络防火墙 防火墙分两种 &#xff1a; 1、 病毒防火墙 也就是我们说的杀毒软件 2、 网络防火墙 这是用来防软件恶意通信的 使用防火墙 有两种 1、 半开式规则…...

美食聚焦 -- 仿大众点评项目技术难点总结

1 实现点赞功能显示哪些用户点赞过并安装时间顺序排序 使用sort_set 进行存储&#xff0c;把博客id作为key&#xff0c;用户id作为value&#xff0c;时间戳作为score 但存储成功之后还是没有成功按照时间顺序排名&#xff0c;因为sql语句&#xff0c;比如最后in&#xff08;5…...

拓扑图:揭示复杂系统背后的结构与逻辑

在现代软件开发和运维中,图形化的表示方式越来越重要。拓扑图,作为一种关键的可视化工具,不仅能够帮助我们理解系统的结构和组件间的关系,还能提升系统的可维护性和可扩展性。 什么是拓扑图? 拓扑图是一种展示系统或网络中各个节点(如服务器、交换机、数据库等)及其连…...

Java面试八股之什么是spring boot starter

什么是spring boot starter Spring Boot Starter是Spring Boot项目中的一个重要概念。它是一种依赖管理机制&#xff0c;用于简化Maven或Gradle配置文件中的依赖项声明。Spring Boot Starter提供了一组预定义的依赖关系&#xff0c;这些依赖关系被封装在一个单一的包中&#x…...

探究项目未能获得ASPICE 1、2级能力的原因及改进策略

项目整体未能获得ASPICE 1、2级能力的原因可能涉及多个方面&#xff0c;以下是基于参考文章中的信息和可能的情境进行的分析&#xff1a; 1.过程成熟度不足&#xff1a;ASPICE&#xff08;Automotive Software Process Improvement and Capability Determination&#xff09;是…...

挑战杯推荐项目

“人工智能”创意赛 - 智能艺术创作助手&#xff1a;借助大模型技术&#xff0c;开发能根据用户输入的主题、风格等要求&#xff0c;生成绘画、音乐、文学作品等多种形式艺术创作灵感或初稿的应用&#xff0c;帮助艺术家和创意爱好者激发创意、提高创作效率。 ​ - 个性化梦境…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面

代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口&#xff08;适配服务端返回 Token&#xff09; export const login async (code, avatar) > {const res await http…...

LeetCode - 199. 二叉树的右视图

题目 199. 二叉树的右视图 - 力扣&#xff08;LeetCode&#xff09; 思路 右视图是指从树的右侧看&#xff0c;对于每一层&#xff0c;只能看到该层最右边的节点。实现思路是&#xff1a; 使用深度优先搜索(DFS)按照"根-右-左"的顺序遍历树记录每个节点的深度对于…...

NPOI操作EXCEL文件 ——CAD C# 二次开发

缺点:dll.版本容易加载错误。CAD加载插件时&#xff0c;没有加载所有类库。插件运行过程中用到某个类库&#xff0c;会从CAD的安装目录找&#xff0c;找不到就报错了。 【方案2】让CAD在加载过程中把类库加载到内存 【方案3】是发现缺少了哪个库&#xff0c;就用插件程序加载进…...

群晖NAS如何在虚拟机创建飞牛NAS

套件中心下载安装Virtual Machine Manager 创建虚拟机 配置虚拟机 飞牛官网下载 https://iso.liveupdate.fnnas.com/x86_64/trim/fnos-0.9.2-863.iso 群晖NAS如何在虚拟机创建飞牛NAS - 个人信息分享...

为什么要创建 Vue 实例

核心原因:Vue 需要一个「控制中心」来驱动整个应用 你可以把 Vue 实例想象成你应用的**「大脑」或「引擎」。它负责协调模板、数据、逻辑和行为,将它们变成一个活的、可交互的应用**。没有这个实例,你的代码只是一堆静态的 HTML、JavaScript 变量和函数,无法「活」起来。 …...

Web后端基础(基础知识)

BS架构&#xff1a;Browser/Server&#xff0c;浏览器/服务器架构模式。客户端只需要浏览器&#xff0c;应用程序的逻辑和数据都存储在服务端。 优点&#xff1a;维护方便缺点&#xff1a;体验一般 CS架构&#xff1a;Client/Server&#xff0c;客户端/服务器架构模式。需要单独…...

智能职业发展系统:AI驱动的职业规划平台技术解析

智能职业发展系统&#xff1a;AI驱动的职业规划平台技术解析 引言&#xff1a;数字时代的职业革命 在当今瞬息万变的就业市场中&#xff0c;传统的职业规划方法已无法满足个人和企业的需求。据统计&#xff0c;全球每年有超过2亿人面临职业转型困境&#xff0c;而企业也因此遭…...

向量几何的二元性:叉乘模长与内积投影的深层联系

在数学与物理的空间世界中&#xff0c;向量运算构成了理解几何结构的基石。叉乘&#xff08;外积&#xff09;与点积&#xff08;内积&#xff09;作为向量代数的两大支柱&#xff0c;表面上呈现出截然不同的几何意义与代数形式&#xff0c;却在深层次上揭示了向量间相互作用的…...

Qt Quick Controls模块功能及架构

Qt Quick Controls是Qt Quick的一个附加模块&#xff0c;提供了一套用于构建完整用户界面的UI控件。在Qt 6.0中&#xff0c;这个模块经历了重大重构和改进。 一、主要功能和特点 1. 架构重构 完全重写了底层架构&#xff0c;与Qt Quick更紧密集成 移除了对Qt Widgets的依赖&…...