当前位置: 首页 > news >正文

视频与AI,与进程交互(二) pytorch 极简训练自己的数据集并识别

目标学习任务

检测出已经分割出的图像的分类

2 使用pytorch

pytorch 非常简单就可以做到训练和加载

2.1 准备数据

在这里插入图片描述
如上图所示,用来训练的文件放在了train中,验证的文件放在val中,train.txt 和 val.txt 分别放文件名称和分类类别,然后我们在代码中写名字就行

里面我就为了做一个例子,放了两种文件,1 是 卡宴保时捷,2 是工程车,如下图所示
在这里插入图片描述
train.txt 如下图所示
在这里插入图片描述
val.txt 也是同样如此

3 show me the code

3.1 装载数据类

新增一个loaddata.py 文件

import torch
import random
from PIL import Image
class LoadData(torch.utils.data.Dataset):def __init__(self, root, datatxt, transform=None, target_transform=None):super(LoadData, self).__init__()file_txt = open(datatxt,'r')imgs = []for line in file_txt:line = line.rstrip()words = line.split('|')imgs.append((words[0], words[1]))self.imgs = imgsself.root = rootself.transform = transformself.target_transform = target_transformdef __getitem__(self, index):random.shuffle(self.imgs)name, label = self.imgs[index]img = Image.open(self.root + name).convert('RGB')if self.transform is not None:img = self.transform(img)label = int(label)return img, labeldef __len__(self):return len(self.imgs)

LoadData 类是从torch.util.data.Dataset上继承下来的,需要一个transform类输入,实际上就是转化大小

3.2 网络类

定义一个网络类,只有两个输出

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3)self.pool = nn.MaxPool2d((2, 2))self.pool1 = nn.MaxPool2d((2, 2))self.conv2 = nn.Conv2d(16, 32, 3)self.fc1 = nn.Linear(36*36*32, 120)self.fc2 = nn.Linear(120, 60)self.fc3 = nn.Linear(60, 2)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool1(F.relu(self.conv2(x)))x = x.view(-1, 36*36*32)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x

3.3 主要流程

import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
from loaddata import LoadData
from modelnet import Netdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)classes = ['工程车','卡宴']
transform = transforms.Compose([transforms.Resize((152, 152)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_data=LoadData(root ='./data/train/',datatxt='./data/'+'train.txt',transform=transform)
test_data=LoadData(root ='./data/val/',datatxt='./data/'+'val.txt',transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=2)def imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)for epoch in range(10):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 200 == 0:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 200))running_loss = 0.0print('Finished Training')PATH = './test.pth'
torch.save(net.state_dict(), PATH)net = Net()
net.load_state_dict(torch.load(PATH))correct = 0
total = 0
with torch.no_grad():for data in test_loader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

在这里插入图片描述
如上图所示,epoch为5时精确度为80%,为10时精确度为100%,各位不要当真,这这是训练集里面的数据集做识别,并不是真的精确度。

3.4 识别代码

import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from modelnet import NetPATH = './test.pth'
transform = transforms.Compose([transforms.Resize((152, 152)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])net = Net()
net.load_state_dict(torch.load(PATH))img = Image.open("./data/val/102.jpg").convert('RGB')
img = transform(img)
with torch.no_grad():outputs = net(img)_, predicted = torch.max(outputs.data, 1)print("the 102 img lable is ",predicted)

如下图所示,102 为卡宴识别为1 正确
在这里插入图片描述

后记

后面我们准备是从视频中传递过来图像进行分类,同时使用我们的工具VT解码视频后进行内存共享来生成图像,而不是从磁盘加载。要用到我们的c++ 解码工具,和pytorch进行交互
以下是第一篇文章:视频与AI,与进程交互(一)
VT 工具准备开源,端午节节后开出来

相关文章:

视频与AI,与进程交互(二) pytorch 极简训练自己的数据集并识别

目标学习任务 检测出已经分割出的图像的分类 2 使用pytorch pytorch 非常简单就可以做到训练和加载 2.1 准备数据 如上图所示,用来训练的文件放在了train中,验证的文件放在val中,train.txt 和 val.txt 分别放文件名称和分类类别&#xff…...

LLM - 第2版 ChatGLM2-6B (General Language Model) 的工程配置

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/131445696 ChatGLM2-6B 是开源中英双语对话模型 ChatGLM-6B 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优…...

从0开始,手写MySQL事务

说在前面:从0开始,手写MySQL的学习价值 尼恩曾经指导过的一个7年经验小伙,凭借精通Mysql, 搞定月薪40K。 从0开始,手写一个MySQL的学习价值在于: 可以深入地理解MySQL的内部机制和原理,Mysql可谓是面试的…...

React中useState的setState方法请求了好多次

1、问题描述 最近在写react的时候碰到了一个很奇怪的问题。 可以看到那个getXXX()的方法一直不断的被调用,网页一直请求,根本停不下来了。 2、产生原因 要弄明白这个原因,首先要先了解一下react生命周期。 react是组件式的编程,一…...

【MYSQL基础】基础命令介绍

基础命令 MYSQL注释方式 -- 单行注释/* 多行注释 哈哈哈哈哈 哈哈哈哈 */连接数据库 mysql -u root -p12345678退出数据库连接 使用exit;命令可以退出连接 查询MYSQL版本 mysql> select version(); ----------- | version() | ----------- | 8.0.27 | ----------- 1…...

多元回归预测 | Matlab基于灰狼算法优化深度置信网络(GWO-DBN)的数据回归预测,matlab代码回归预测,多变量输入模型

文章目录 效果一览文章概述部分源码参考资料效果一览 文章概述 多元回归预测 | Matlab基于灰狼算法优化深度置信网络(GWO-DBN)的数据回归预测,matlab代码回归预测,多变量输入模型,matlab代码回归预测,多变量输入模型,多变量输入模型 评价指标包括:MAE、RMSE和R2等,代码质…...

校园wifi网页认证登录入口

很多校园wifi网页认证登录入口是1.1.1.1 连上校园网在浏览器写上http://1.1.1.1就进入了校园网 使 用 说 明 一、帐户余额 < 0.00元时&#xff0c;帐号被禁用&#xff0c;需追加网费。 二、在计算中心机房上机的用户&#xff0c;登录时请选择新建帐号时给您指定的NT域&…...

[SpringBoot]Spring Security框架

目录 关于Spring Security框架 Spring Security框架的依赖项 Spring Security框架的典型特征 关于Spring Security的配置 关于默认的登录页 关于请求的授权访问&#xff08;访问控制&#xff09; 使用自定义的账号登录 使用数据库中的账号登录 关于密码编码器 使用BCry…...

Unity 之 抖音小游戏本地数据最新存储方法分享

Unity 之 抖音小游戏本地数据最新存储方法分享 一、抖音小游戏文件存储系统背景二、文件存储系统的使用方法2.1 初始化2.1 创建目录2.3 存储数据2.4 删除目录/文件2.5 其他相关操作 三&#xff0c;小结 抖音小游戏是一种基于抖音平台开发的小型游戏&#xff0c;与传统的 APP 不…...

逍遥自在学C语言 | 函数初级到高级解析

前言 函数是C语言中的基本构建块之一&#xff0c;它允许我们将代码组织成可重用、模块化的单元。 本文将逐步介绍C语言函数的基础概念、参数传递、返回值、递归以及内联函数和匿名函数。 一、人物简介 第一位闪亮登场&#xff0c;有请今后会一直教我们C语言的老师 —— 自在…...

Elastic 推出 Elastic AI 助手

作者&#xff1a;Mike Nichols Elastic 推出了 Elastic AI Assistant&#xff0c;这是一款由 ESRE 提供支持的开放式、生成式 AI 助手&#xff0c;旨在使网络安全民主化并支持各种技能水平的用户。 最近发布的 Elasticsearch Relevance Engine™ (ESRE™) 提供了用于创建高度相…...

【数据库】MySQL安装(最新图文保姆级别超详细版本介绍)

1.总共两部分&#xff08;第二部可省略&#xff09; 安装mysql体验mysql环境变量配置 1.1安装mysql 1.输入官网地址https://www.mysql.com/ 下载完成后&#xff0c;我们双击打开我们的下载文件 打开后的界面&#xff0c;如图所示 我们选择custom&#xff0c;点击nex…...

前端使用pdf-lib库实现pdf合并,window.open预览合并后的pdf

最近出差开了好多发票&#xff0c;写了一个pdf合并网站&#xff0c;用于把多张发票pdf合并成一张&#xff0c;方便打印 使用pdf-lib这个库实现的pdf合并功能&#xff0c;预览使用的是浏览器自身查看pdf功能 源码 网页地址 https://zqy233.github.io/PDF-merge/ <!DOCTYPE h…...

计算机网络相关知识点总结(二)

比特bit是计算机中数据量的最小单位,可简记为b。字节Byte也是计算机中数据量的单位,可简记为B,1B8bit。常用的数据量单位还有kB、MB、GB、TB等,其中k、M、G、T的数值分别为 2 10 2^{10} 210, 2 20 2^{20} 220, 2 30 2^{30} 230, 2 40 2^{40} 240。 K, M, G, T 分别对应以下…...

Redmine与Gitlab整合(实战版)

网上查了很多文章&#xff0c;总结一下。 安装过程略。可参考&#xff1a;(84条消息) Redmine与Gitlab功能集成_redmine gitlab_羽之大公公的博客-CSDN博客 配置集成的方法&#xff0c;参考&#xff1a; Redmine与GitLab集成 (ngui.cc) 修改ssh-key密码的方法&#xff0c;参…...

(3)深度学习学习笔记-简单线性模型

文章目录 一、线性模型二、实例1.pytorch求导功能2.简单线性模型&#xff08;人工数据集&#xff09; 来源 一、线性模型 一个简单模型&#xff1a;假设一个房子的价格由卧室、卫生间、居住面积决定&#xff0c;用x1&#xff0c;x2&#xff0c;x3表示。 那么房价y就可以认为yw…...

pytorch3d 安装报错 RuntimeError: Not compiled with GPU support pytorch3d

安装环境 NVIDIA GeForce RTX 3090 cuda 11.3 python 3.8.5 torch 1.11.0 torchvision 0.12.0 环境安装命令 conda install pytorch1.11.0 torchvision0.12.0 torchaudio0.11.0 cudatoolkit11.3 -c pytorch安装pytorch3d参考官网链接 https://github.com/facebookresearch/p…...

spring工程的启动流程?bean的生命周期?提供哪些扩展点?管理事务?解决循环依赖问题的?事务传播行为有哪些?

1.Spring工程的启动流程&#xff1a; Spring工程的启动流程主要包括以下几个步骤&#xff1a; 加载配置文件&#xff1a;Spring会读取配置文件&#xff08;如XML配置文件或注解配置&#xff09;来获取应用程序的配置信息。实例化并初始化IoC容器&#xff1a;Spring会创建并初…...

使用 Zabbix 监控 RocketMQ列举监控项和触发器

在使用 Zabbix 监控 RocketMQ 的过程中&#xff0c;以下是一些可能的监控项和触发器&#xff1a; 监控项 集群总体健康状况生产者和消费者的连接数量Broker 的状态消息的生产和消费速度队列深度&#xff08;即队列中的消息数量&#xff09;磁盘空间使用内存使用CPU使用网络流…...

uniApp:路由与页面跳转及传参

方式一&#xff1a;声明式导航 声明式导航&#xff0c;通过组件进行跳转。官方文档&#xff1a;详情 使用 navigator 组件进行页面跳转。 属性类型默认值说明urlString应用内的跳转链接&#xff0c;值为相对路径或绝对路径&#xff0c;如&#xff1a;“…/first/first”&#x…...

C++初阶-list的底层

目录 1.std::list实现的所有代码 2.list的简单介绍 2.1实现list的类 2.2_list_iterator的实现 2.2.1_list_iterator实现的原因和好处 2.2.2_list_iterator实现 2.3_list_node的实现 2.3.1. 避免递归的模板依赖 2.3.2. 内存布局一致性 2.3.3. 类型安全的替代方案 2.3.…...

JavaScript 中的 ES|QL:利用 Apache Arrow 工具

作者&#xff1a;来自 Elastic Jeffrey Rengifo 学习如何将 ES|QL 与 JavaScript 的 Apache Arrow 客户端工具一起使用。 想获得 Elastic 认证吗&#xff1f;了解下一期 Elasticsearch Engineer 培训的时间吧&#xff01; Elasticsearch 拥有众多新功能&#xff0c;助你为自己…...

MFC内存泄露

1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…...

【入坑系列】TiDB 强制索引在不同库下不生效问题

文章目录 背景SQL 优化情况线上SQL运行情况分析怀疑1:执行计划绑定问题?尝试:SHOW WARNINGS 查看警告探索 TiDB 的 USE_INDEX 写法Hint 不生效问题排查解决参考背景 项目中使用 TiDB 数据库,并对 SQL 进行优化了,添加了强制索引。 UAT 环境已经生效,但 PROD 环境强制索…...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八

现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet&#xff0c;点击确认后如下提示 最终上报fail 解决方法 内核升级导致&#xff0c;需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...

条件运算符

C中的三目运算符&#xff08;也称条件运算符&#xff0c;英文&#xff1a;ternary operator&#xff09;是一种简洁的条件选择语句&#xff0c;语法如下&#xff1a; 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true&#xff0c;则整个表达式的结果为“表达式1”…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成

厌倦手动写WordPress文章&#xff1f;AI自动生成&#xff0c;效率提升10倍&#xff01; 支持多语言、自动配图、定时发布&#xff0c;让内容创作更轻松&#xff01; AI内容生成 → 不想每天写文章&#xff1f;AI一键生成高质量内容&#xff01;多语言支持 → 跨境电商必备&am…...

C++八股 —— 单例模式

文章目录 1. 基本概念2. 设计要点3. 实现方式4. 详解懒汉模式 1. 基本概念 线程安全&#xff08;Thread Safety&#xff09; 线程安全是指在多线程环境下&#xff0c;某个函数、类或代码片段能够被多个线程同时调用时&#xff0c;仍能保证数据的一致性和逻辑的正确性&#xf…...

CMake控制VS2022项目文件分组

我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...

HarmonyOS运动开发:如何用mpchart绘制运动配速图表

##鸿蒙核心技术##运动开发##Sensor Service Kit&#xff08;传感器服务&#xff09;# 前言 在运动类应用中&#xff0c;运动数据的可视化是提升用户体验的重要环节。通过直观的图表展示运动过程中的关键数据&#xff0c;如配速、距离、卡路里消耗等&#xff0c;用户可以更清晰…...