当前位置: 首页 > news >正文

基于迁移学习的手势分类模型训练

1、基本原理介绍

       这里介绍的单指模型迁移。一般我们训练模型时,往往会自定义一个模型类,这个类中定义了神经网络的结构,训练时将数据集输入,从0开始训练;而迁移学习中(单指模型迁移策略),是在一个已经有过训练基础的模型上,用自己的数据集,进一步训练,使得这个模型能够完成我们需要的任务。

这么做有有这样几个显而易见的好处:

※  因为模型之前被训练过,所以初始参数不会是0,这样能够加速模型训练

※  因为预训练模型(什么是预训练模型下文会讲到)在其他数据集上训练过,而其他数据集往往和我们用的数据集存在一定的区别,所以这可以提高模型的泛化能力

※  通过迁移学习,可以将来自大规模数据的优势转移到小规模或新任务上,提高模型的表现和效果

2、预训练模型

        在进行迁移学习时,我们要先找到一个预训练模型。在分类任务领域,比较流行的如resnet系列、mobilenet系列(更轻量化)、vgg(系列)、efficientnet(系列)等等网络,都是比较常用且容易获得的预训练模型,这些模型都能够通过python直接下载。

        而且由于上述模型基本都是在ImageNet这一大规模,多分类类别的数据集上进行过训练的,所以对于简单的二分类等少数类别分类,能有较好的效果。

3、训练流程

迁移学习完整的训练流程和一般搭建神经网络的训练模型的流程基本类似:数据预处理->数据集的切分->加载预训练模型(搭建神经网络)->设置超参数/损失函数/优化器等->训练模型

3.1 模型训练

下面的代码是一个利用mobilenet网络训练得到的手势分类模型,该模型能够较准确的分类不同类别手势。

相关解释已在代码中注释说明。

from torchvision.models import mobilenet_v2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation# 定义数据预处理和增强器
transform = Compose([RandomHorizontalFlip(),  # 随机水平翻转RandomRotation(10),      # 随机旋转10度Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载数据集并应用预处理和增强器
dataset = ImageFolder(root='data', transform=transform)
# 这里由于数据比较少,将所有数据集全部用来训练,得到的模型直接拿来用了,这其实不算是非常规范的操作,仅供参考# 定义网络结构
model = mobilenet_v2(pretrained=True)  # 加载预训练模型,也可以试试其他模型,效果差别挺大的
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 5)  # 假设是5分类问题,具体几分类,改这里的参数就行了# 将模型移动到设备上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)# 定义优化器和损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()# 定义训练循环
def train_model(model, criterion, optimizer, num_epochs, train_loader):for epoch in range(num_epochs):model.train()  # 设置模型为训练模式train_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item() * inputs.size(0)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()epoch_loss = train_loss / totalepoch_acc = 100. * correct / totalprint(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')# 创建训练集的DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)# 开始训练模型
train_model(model, criterion, optimizer, num_epochs=15, train_loader=train_loader)
torch.save(model, 'my_model(1).pth')

3.2 数据集文件结构

当然,你也可以自己定义读取数据集的data_loader类。

3.3 模型推理

这段代码是用训练得到的模型对一张图片进行推理测试的,如果需要对系列图片进行推理,评估模型效果,可自行修改,调用对应函数即可。

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
def predict_image(image_path, model_path='my_model(1).pth'):image = Image.open(image_path).convert("RGB")# 对测试的图片进行预处理,需要和训练时处理的方式一样transform = Compose([Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])image_tensor = transform(image).unsqueeze(0)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')image_tensor = image_tensor.to(device)model = torch.load(model_path,map_location=device)model.eval()with torch.no_grad():output = model(image_tensor)_, predicted = torch.max(output.data, 1)  # 获得分类标记return predicted.item()
if __name__=="__main__":image_path = "test2/6.jpg"print(predict_image(image_path))

3.4 整体项目文件

4、补充说明

        这种利用迁移学习策略,进行少类别,不同类别特征差距小的任务需求来说,效果一般来说是比较好的。因为之前做过相关实验,准确率90%以上是很容易的,所以这里没有模型评估,生成混淆矩阵等过程。对于多类别分类,建议有完整的评估体系。

        上述使用的方法仅适用于分类任务,对于真正的目标检测如手势识别,直接使用该模型的问题是:由于无法定位手势的位置,所以导致识别不准确。

        本实验数据集是不同类别手势图片,为自制,不开源。

相关文章:

基于迁移学习的手势分类模型训练

1、基本原理介绍 这里介绍的单指模型迁移。一般我们训练模型时,往往会自定义一个模型类,这个类中定义了神经网络的结构,训练时将数据集输入,从0开始训练;而迁移学习中(单指模型迁移策略)&#x…...

个性化音频生成GPT-SoVits部署使用和API调用

一、训练自己的音色模型步骤 1、准备好要训练的数据,放在Data文件夹中,按照文件模板中的结构进行存放数据 2、双击打开go-webui.bat文件,等待页面跳转 3、页面打开后,开始训练自己的模型 (1)、人声伴奏分…...

MFC列表框示例

本文仅供学习交流,严禁用于商业用途,如本文涉及侵权请及时联系本人将于及时删除 目录 1.示例内容 2.程序步骤 3.运行结果 4.代码全文 1.示例内容 编写一个对话框应用程序CMFC_Li6_4_学生信息Dlg,对话框中有一个列表框,当用户…...

Android TabLayout的简单用法

TabLayout 注意这里添加tab,使用binding.tabLayout.newTab()进行创建 private fun initTabs() {val tab binding.tabLayout.newTab()tab.text "模板库"binding.tabLayout.addTab(tab)binding.tabLayout.addOnTabSelectedListener(object : TabLayout.On…...

基于vite + pnpm monorepo 实现一个UI组件库

基于vite pnpm monorepo的vue组件库 仓库地址 思路 好多文章都是直接咔咔咔的上代码。跟着做也没问题,但总觉得少了些什么。下次做的时候还要找文章参考。。 需求有三个模块,那么就需要三个包。使用monorepo进行分包管理。 a. 组件库 b. 组件库文档…...

FDM3D打印系列——Luck13关节可动模型打印和各种材料的尝试

luck13可动关节模型FDM3D打印制作过程 大家好,我是阿赵。   最近我沉迷于打印一个叫做Luck13的关节超可动人偶。 首先说明一下,这个模型是分为了外甲和骨骼两个部分的。   为什么我会打印了这么多个呢? 一、第一次尝试——PLATPU 刚开始…...

windows10 获取磁盘类型

powershell Get-PhysicalDisk | Select FriendlyName, MediaType FriendlyName MediaType ------------ --------- NVMe PC SN740 NVMe WD 256GB SSD WDC WD10EZEX-75WN4A1 HDD 适用场景 SSD: 适合需要快速访问速度和较高响…...

数据库之运算符

目录 一、算数运算符 二、比较运算符 1.常用比较运算符 2.实现特殊功能的比较运算符 三、逻辑运算符 1.逻辑与运算符(&&或者AND) 2.逻辑或运算符(||或者OR) 3.逻辑非运算符(!或者NOT&#…...

【自动化机器学习AutoML】AutoML工具和平台的使用

自动化机器学习AutoML:AutoML工具和平台的使用 目录 引言什么是AutoMLAutoML的优势常见的AutoML工具和平台 Google Cloud AutoMLH2O.aiAuto-sklearnTPOTMLBox AutoML的基本使用 Google Cloud AutoML使用示例Auto-sklearn使用示例 AutoML的应用场景结论 引言 自动…...

【每日一练】python求最后一个单词的长度

""" 求某变量中最后一个单词的长度 例如s"Good morning, champ! Youre going to rock this day" 分析思路: 遇到字符串问题,经常和列表结合使用来解决, 可以先用列表的.split()分割方法进行单词分割, 再…...

[红明谷CTF 2021]write_shell 1

目录 代码审计check()$_GET["action"] ?? "" 解题 代码审计 <?php error_reporting(0); highlight_file(__FILE__); function check($input){if(preg_match("/| |_|php|;|~|\\^|\\|eval|{|}/i",$input)){// if(preg_match("/| |_||p…...

【Go - sync.once】

sync.Once 是 Go 语言标准库中的一个结构体&#xff0c;它的作用是确保某个操作在全局范围内只被执行一次。这对于实现单例模式或需要一次性初始化资源的场景非常有用。 典型用法 sync.Once 提供了一个方法 Do(f func())&#xff0c;该方法接收一个没有参数和返回值的函数 f …...

Spark RPC框架详解

文章目录 前言Spark RPC模型概述RpcEndpointRpcEndpointRefRpcEnv 基于Netty的RPC实现NettyRpcEndpointRefNettyRpcEnv消息的发送消息的接收RpcEndpointRef的构造方式直接通过RpcEndpoint构造RpcEndpointRef通过消息发送RpcEndpointRef Endpoint的注册Dispatcher消息的投递消息…...

win10安装ElasticSearch7.x和分词插件

说明&#xff1a; 以下内容整理自网络&#xff0c;格式调整优化&#xff0c;更易阅读&#xff0c;希望能对需要的人有所帮助。 一 安装 Java环境 ElasticSearch使用Java开发的&#xff0c;依赖Java环境&#xff0c;安装 ElasticSearch 7.x 之前&#xff0c;需要先安装jdk-8。…...

Linux中,MySQL的用户管理

MySQL库中的表及其作用 user表 User表是MySQL中最重要的一个权限表&#xff0c;记录允许连接到服务器的帐号信息&#xff0c;里面的权限是全局级的。 db表和host表 db表和host表是MySQL数据中非常重要的权限表。db表中存储了用户对某个数据库的操作权限&#xff0c;决定用户…...

个人电脑网络安全 之 防浏览器和端口溢出攻击 和 权限对系统的重要性

防浏览器和端口溢出攻击 该如何防 很多人都不明白 我相信很多人只知道杀毒软件 却不知道网络防火墙 防火墙分两种 &#xff1a; 1、 病毒防火墙 也就是我们说的杀毒软件 2、 网络防火墙 这是用来防软件恶意通信的 使用防火墙 有两种 1、 半开式规则…...

美食聚焦 -- 仿大众点评项目技术难点总结

1 实现点赞功能显示哪些用户点赞过并安装时间顺序排序 使用sort_set 进行存储&#xff0c;把博客id作为key&#xff0c;用户id作为value&#xff0c;时间戳作为score 但存储成功之后还是没有成功按照时间顺序排名&#xff0c;因为sql语句&#xff0c;比如最后in&#xff08;5…...

拓扑图:揭示复杂系统背后的结构与逻辑

在现代软件开发和运维中,图形化的表示方式越来越重要。拓扑图,作为一种关键的可视化工具,不仅能够帮助我们理解系统的结构和组件间的关系,还能提升系统的可维护性和可扩展性。 什么是拓扑图? 拓扑图是一种展示系统或网络中各个节点(如服务器、交换机、数据库等)及其连…...

Java面试八股之什么是spring boot starter

什么是spring boot starter Spring Boot Starter是Spring Boot项目中的一个重要概念。它是一种依赖管理机制&#xff0c;用于简化Maven或Gradle配置文件中的依赖项声明。Spring Boot Starter提供了一组预定义的依赖关系&#xff0c;这些依赖关系被封装在一个单一的包中&#x…...

探究项目未能获得ASPICE 1、2级能力的原因及改进策略

项目整体未能获得ASPICE 1、2级能力的原因可能涉及多个方面&#xff0c;以下是基于参考文章中的信息和可能的情境进行的分析&#xff1a; 1.过程成熟度不足&#xff1a;ASPICE&#xff08;Automotive Software Process Improvement and Capability Determination&#xff09;是…...

Java 8 Stream API 入门到实践详解

一、告别 for 循环&#xff01; 传统痛点&#xff1a; Java 8 之前&#xff0c;集合操作离不开冗长的 for 循环和匿名类。例如&#xff0c;过滤列表中的偶数&#xff1a; List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...

Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具

文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...

OkHttp 中实现断点续传 demo

在 OkHttp 中实现断点续传主要通过以下步骤完成&#xff0c;核心是利用 HTTP 协议的 Range 请求头指定下载范围&#xff1a; 实现原理 Range 请求头&#xff1a;向服务器请求文件的特定字节范围&#xff08;如 Range: bytes1024-&#xff09; 本地文件记录&#xff1a;保存已…...

C++ 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...

mysql已经安装,但是通过rpm -q 没有找mysql相关的已安装包

文章目录 现象&#xff1a;mysql已经安装&#xff0c;但是通过rpm -q 没有找mysql相关的已安装包遇到 rpm 命令找不到已经安装的 MySQL 包时&#xff0c;可能是因为以下几个原因&#xff1a;1.MySQL 不是通过 RPM 包安装的2.RPM 数据库损坏3.使用了不同的包名或路径4.使用其他包…...

蓝桥杯3498 01串的熵

问题描述 对于一个长度为 23333333的 01 串, 如果其信息熵为 11625907.5798&#xff0c; 且 0 出现次数比 1 少, 那么这个 01 串中 0 出现了多少次? #include<iostream> #include<cmath> using namespace std;int n 23333333;int main() {//枚举 0 出现的次数//因…...

USB Over IP专用硬件的5个特点

USB over IP技术通过将USB协议数据封装在标准TCP/IP网络数据包中&#xff0c;从根本上改变了USB连接。这允许客户端通过局域网或广域网远程访问和控制物理连接到服务器的USB设备&#xff08;如专用硬件设备&#xff09;&#xff0c;从而消除了直接物理连接的需要。USB over IP的…...

Docker 本地安装 mysql 数据库

Docker: Accelerated Container Application Development 下载对应操作系统版本的 docker &#xff1b;并安装。 基础操作不再赘述。 打开 macOS 终端&#xff0c;开始 docker 安装mysql之旅 第一步 docker search mysql 》〉docker search mysql NAME DE…...

08. C#入门系列【类的基本概念】:开启编程世界的奇妙冒险

C#入门系列【类的基本概念】&#xff1a;开启编程世界的奇妙冒险 嘿&#xff0c;各位编程小白探险家&#xff01;欢迎来到 C# 的奇幻大陆&#xff01;今天咱们要深入探索这片大陆上至关重要的 “建筑”—— 类&#xff01;别害怕&#xff0c;跟着我&#xff0c;保准让你轻松搞…...

Python Einops库:深度学习中的张量操作革命

Einops&#xff08;爱因斯坦操作库&#xff09;就像给张量操作戴上了一副"语义眼镜"——让你用人类能理解的方式告诉计算机如何操作多维数组。这个基于爱因斯坦求和约定的库&#xff0c;用类似自然语言的表达式替代了晦涩的API调用&#xff0c;彻底改变了深度学习工程…...