当前位置: 首页 > news >正文

视频与AI,与进程交互(二) pytorch 极简训练自己的数据集并识别

目标学习任务

检测出已经分割出的图像的分类

2 使用pytorch

pytorch 非常简单就可以做到训练和加载

2.1 准备数据

在这里插入图片描述
如上图所示,用来训练的文件放在了train中,验证的文件放在val中,train.txt 和 val.txt 分别放文件名称和分类类别,然后我们在代码中写名字就行

里面我就为了做一个例子,放了两种文件,1 是 卡宴保时捷,2 是工程车,如下图所示
在这里插入图片描述
train.txt 如下图所示
在这里插入图片描述
val.txt 也是同样如此

3 show me the code

3.1 装载数据类

新增一个loaddata.py 文件

import torch
import random
from PIL import Image
class LoadData(torch.utils.data.Dataset):def __init__(self, root, datatxt, transform=None, target_transform=None):super(LoadData, self).__init__()file_txt = open(datatxt,'r')imgs = []for line in file_txt:line = line.rstrip()words = line.split('|')imgs.append((words[0], words[1]))self.imgs = imgsself.root = rootself.transform = transformself.target_transform = target_transformdef __getitem__(self, index):random.shuffle(self.imgs)name, label = self.imgs[index]img = Image.open(self.root + name).convert('RGB')if self.transform is not None:img = self.transform(img)label = int(label)return img, labeldef __len__(self):return len(self.imgs)

LoadData 类是从torch.util.data.Dataset上继承下来的,需要一个transform类输入,实际上就是转化大小

3.2 网络类

定义一个网络类,只有两个输出

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3)self.pool = nn.MaxPool2d((2, 2))self.pool1 = nn.MaxPool2d((2, 2))self.conv2 = nn.Conv2d(16, 32, 3)self.fc1 = nn.Linear(36*36*32, 120)self.fc2 = nn.Linear(120, 60)self.fc3 = nn.Linear(60, 2)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool1(F.relu(self.conv2(x)))x = x.view(-1, 36*36*32)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x

3.3 主要流程

import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
from loaddata import LoadData
from modelnet import Netdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)classes = ['工程车','卡宴']
transform = transforms.Compose([transforms.Resize((152, 152)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_data=LoadData(root ='./data/train/',datatxt='./data/'+'train.txt',transform=transform)
test_data=LoadData(root ='./data/val/',datatxt='./data/'+'val.txt',transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=2)def imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)for epoch in range(10):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 200 == 0:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 200))running_loss = 0.0print('Finished Training')PATH = './test.pth'
torch.save(net.state_dict(), PATH)net = Net()
net.load_state_dict(torch.load(PATH))correct = 0
total = 0
with torch.no_grad():for data in test_loader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

在这里插入图片描述
如上图所示,epoch为5时精确度为80%,为10时精确度为100%,各位不要当真,这这是训练集里面的数据集做识别,并不是真的精确度。

3.4 识别代码

import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from modelnet import NetPATH = './test.pth'
transform = transforms.Compose([transforms.Resize((152, 152)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])net = Net()
net.load_state_dict(torch.load(PATH))img = Image.open("./data/val/102.jpg").convert('RGB')
img = transform(img)
with torch.no_grad():outputs = net(img)_, predicted = torch.max(outputs.data, 1)print("the 102 img lable is ",predicted)

如下图所示,102 为卡宴识别为1 正确
在这里插入图片描述

后记

后面我们准备是从视频中传递过来图像进行分类,同时使用我们的工具VT解码视频后进行内存共享来生成图像,而不是从磁盘加载。要用到我们的c++ 解码工具,和pytorch进行交互
以下是第一篇文章:视频与AI,与进程交互(一)
VT 工具准备开源,端午节节后开出来

相关文章:

视频与AI,与进程交互(二) pytorch 极简训练自己的数据集并识别

目标学习任务 检测出已经分割出的图像的分类 2 使用pytorch pytorch 非常简单就可以做到训练和加载 2.1 准备数据 如上图所示,用来训练的文件放在了train中,验证的文件放在val中,train.txt 和 val.txt 分别放文件名称和分类类别&#xff…...

LLM - 第2版 ChatGLM2-6B (General Language Model) 的工程配置

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/131445696 ChatGLM2-6B 是开源中英双语对话模型 ChatGLM-6B 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优…...

从0开始,手写MySQL事务

说在前面:从0开始,手写MySQL的学习价值 尼恩曾经指导过的一个7年经验小伙,凭借精通Mysql, 搞定月薪40K。 从0开始,手写一个MySQL的学习价值在于: 可以深入地理解MySQL的内部机制和原理,Mysql可谓是面试的…...

React中useState的setState方法请求了好多次

1、问题描述 最近在写react的时候碰到了一个很奇怪的问题。 可以看到那个getXXX()的方法一直不断的被调用,网页一直请求,根本停不下来了。 2、产生原因 要弄明白这个原因,首先要先了解一下react生命周期。 react是组件式的编程,一…...

【MYSQL基础】基础命令介绍

基础命令 MYSQL注释方式 -- 单行注释/* 多行注释 哈哈哈哈哈 哈哈哈哈 */连接数据库 mysql -u root -p12345678退出数据库连接 使用exit;命令可以退出连接 查询MYSQL版本 mysql> select version(); ----------- | version() | ----------- | 8.0.27 | ----------- 1…...

多元回归预测 | Matlab基于灰狼算法优化深度置信网络(GWO-DBN)的数据回归预测,matlab代码回归预测,多变量输入模型

文章目录 效果一览文章概述部分源码参考资料效果一览 文章概述 多元回归预测 | Matlab基于灰狼算法优化深度置信网络(GWO-DBN)的数据回归预测,matlab代码回归预测,多变量输入模型,matlab代码回归预测,多变量输入模型,多变量输入模型 评价指标包括:MAE、RMSE和R2等,代码质…...

校园wifi网页认证登录入口

很多校园wifi网页认证登录入口是1.1.1.1 连上校园网在浏览器写上http://1.1.1.1就进入了校园网 使 用 说 明 一、帐户余额 < 0.00元时&#xff0c;帐号被禁用&#xff0c;需追加网费。 二、在计算中心机房上机的用户&#xff0c;登录时请选择新建帐号时给您指定的NT域&…...

[SpringBoot]Spring Security框架

目录 关于Spring Security框架 Spring Security框架的依赖项 Spring Security框架的典型特征 关于Spring Security的配置 关于默认的登录页 关于请求的授权访问&#xff08;访问控制&#xff09; 使用自定义的账号登录 使用数据库中的账号登录 关于密码编码器 使用BCry…...

Unity 之 抖音小游戏本地数据最新存储方法分享

Unity 之 抖音小游戏本地数据最新存储方法分享 一、抖音小游戏文件存储系统背景二、文件存储系统的使用方法2.1 初始化2.1 创建目录2.3 存储数据2.4 删除目录/文件2.5 其他相关操作 三&#xff0c;小结 抖音小游戏是一种基于抖音平台开发的小型游戏&#xff0c;与传统的 APP 不…...

逍遥自在学C语言 | 函数初级到高级解析

前言 函数是C语言中的基本构建块之一&#xff0c;它允许我们将代码组织成可重用、模块化的单元。 本文将逐步介绍C语言函数的基础概念、参数传递、返回值、递归以及内联函数和匿名函数。 一、人物简介 第一位闪亮登场&#xff0c;有请今后会一直教我们C语言的老师 —— 自在…...

Elastic 推出 Elastic AI 助手

作者&#xff1a;Mike Nichols Elastic 推出了 Elastic AI Assistant&#xff0c;这是一款由 ESRE 提供支持的开放式、生成式 AI 助手&#xff0c;旨在使网络安全民主化并支持各种技能水平的用户。 最近发布的 Elasticsearch Relevance Engine™ (ESRE™) 提供了用于创建高度相…...

【数据库】MySQL安装(最新图文保姆级别超详细版本介绍)

1.总共两部分&#xff08;第二部可省略&#xff09; 安装mysql体验mysql环境变量配置 1.1安装mysql 1.输入官网地址https://www.mysql.com/ 下载完成后&#xff0c;我们双击打开我们的下载文件 打开后的界面&#xff0c;如图所示 我们选择custom&#xff0c;点击nex…...

前端使用pdf-lib库实现pdf合并,window.open预览合并后的pdf

最近出差开了好多发票&#xff0c;写了一个pdf合并网站&#xff0c;用于把多张发票pdf合并成一张&#xff0c;方便打印 使用pdf-lib这个库实现的pdf合并功能&#xff0c;预览使用的是浏览器自身查看pdf功能 源码 网页地址 https://zqy233.github.io/PDF-merge/ <!DOCTYPE h…...

计算机网络相关知识点总结(二)

比特bit是计算机中数据量的最小单位,可简记为b。字节Byte也是计算机中数据量的单位,可简记为B,1B8bit。常用的数据量单位还有kB、MB、GB、TB等,其中k、M、G、T的数值分别为 2 10 2^{10} 210, 2 20 2^{20} 220, 2 30 2^{30} 230, 2 40 2^{40} 240。 K, M, G, T 分别对应以下…...

Redmine与Gitlab整合(实战版)

网上查了很多文章&#xff0c;总结一下。 安装过程略。可参考&#xff1a;(84条消息) Redmine与Gitlab功能集成_redmine gitlab_羽之大公公的博客-CSDN博客 配置集成的方法&#xff0c;参考&#xff1a; Redmine与GitLab集成 (ngui.cc) 修改ssh-key密码的方法&#xff0c;参…...

(3)深度学习学习笔记-简单线性模型

文章目录 一、线性模型二、实例1.pytorch求导功能2.简单线性模型&#xff08;人工数据集&#xff09; 来源 一、线性模型 一个简单模型&#xff1a;假设一个房子的价格由卧室、卫生间、居住面积决定&#xff0c;用x1&#xff0c;x2&#xff0c;x3表示。 那么房价y就可以认为yw…...

pytorch3d 安装报错 RuntimeError: Not compiled with GPU support pytorch3d

安装环境 NVIDIA GeForce RTX 3090 cuda 11.3 python 3.8.5 torch 1.11.0 torchvision 0.12.0 环境安装命令 conda install pytorch1.11.0 torchvision0.12.0 torchaudio0.11.0 cudatoolkit11.3 -c pytorch安装pytorch3d参考官网链接 https://github.com/facebookresearch/p…...

spring工程的启动流程?bean的生命周期?提供哪些扩展点?管理事务?解决循环依赖问题的?事务传播行为有哪些?

1.Spring工程的启动流程&#xff1a; Spring工程的启动流程主要包括以下几个步骤&#xff1a; 加载配置文件&#xff1a;Spring会读取配置文件&#xff08;如XML配置文件或注解配置&#xff09;来获取应用程序的配置信息。实例化并初始化IoC容器&#xff1a;Spring会创建并初…...

使用 Zabbix 监控 RocketMQ列举监控项和触发器

在使用 Zabbix 监控 RocketMQ 的过程中&#xff0c;以下是一些可能的监控项和触发器&#xff1a; 监控项 集群总体健康状况生产者和消费者的连接数量Broker 的状态消息的生产和消费速度队列深度&#xff08;即队列中的消息数量&#xff09;磁盘空间使用内存使用CPU使用网络流…...

uniApp:路由与页面跳转及传参

方式一&#xff1a;声明式导航 声明式导航&#xff0c;通过组件进行跳转。官方文档&#xff1a;详情 使用 navigator 组件进行页面跳转。 属性类型默认值说明urlString应用内的跳转链接&#xff0c;值为相对路径或绝对路径&#xff0c;如&#xff1a;“…/first/first”&#x…...

5分钟搞定AutoHotkey脚本转EXE:Ahk2Exe终极编译指南

5分钟搞定AutoHotkey脚本转EXE&#xff1a;Ahk2Exe终极编译指南 【免费下载链接】Ahk2Exe Official AutoHotkey script compiler - written itself in AutoHotkey 项目地址: https://gitcode.com/gh_mirrors/ah/Ahk2Exe 想要将AutoHotkey脚本快速转换为独立的可执行文件…...

usearch的内存泄漏自动化测试:在CI中集成泄漏检测

usearch的内存泄漏自动化测试&#xff1a;在CI中集成泄漏检测 【免费下载链接】usearch Fastest Open-Source Search & Clustering engine for Vectors & &#x1f51c; Strings in C, C, Python, JavaScript, Rust, Java, Objective-C, Swift, C#, GoLang, and Wolf…...

PCIe金手指设计避坑指南:从硬件选型到PCB布局的5个关键细节

PCIe金手指设计避坑指南&#xff1a;从硬件选型到PCB布局的5个关键细节 在高速数字系统设计中&#xff0c;PCIe金手指接口的可靠性直接决定了扩展卡的识别成功率和数据传输稳定性。许多工程师在完成原理图设计和PCB布局后&#xff0c;常会遇到设备频繁识别失败、链路训练不通过…...

3步解锁音乐自由:NCMDump帮你破解网易云音乐NCM格式

3步解锁音乐自由&#xff1a;NCMDump帮你破解网易云音乐NCM格式 【免费下载链接】ncmdump 项目地址: https://gitcode.com/gh_mirrors/ncmd/ncmdump 还在为下载的网易云音乐只能在特定App里播放而烦恼吗&#xff1f;当你精心挑选的歌单无法在车载音响、运动手表或家庭音…...

Phi-4-mini-reasoning部署实操手册:supervisor服务管理与日志排查指南

Phi-4-mini-reasoning部署实操手册&#xff1a;supervisor服务管理与日志排查指南 1. 模型概述 Phi-4-mini-reasoning 是一个专注于推理任务的文本生成模型&#xff0c;特别适合处理数学题、逻辑题、多步分析和简洁结论输出。与通用聊天模型不同&#xff0c;它采用"题目…...

Zotero Citation插件开发指南:从环境适配到定制优化的全流程实践

Zotero Citation插件开发指南&#xff1a;从环境适配到定制优化的全流程实践 【免费下载链接】zotero-citation Make Zoteros citation in Word easier and clearer. 项目地址: https://gitcode.com/gh_mirrors/zo/zotero-citation 问题发现&#xff1a;学术写作中的引用…...

Wan2.2-I2V-A14B镜像效果展示:夕阳海滩10秒1080P高清视频生成作品集

Wan2.2-I2V-A14B镜像效果展示&#xff1a;夕阳海滩10秒1080P高清视频生成作品集 1. 惊艳的视频生成效果 想象一下&#xff0c;只需简单描述&#xff0c;就能让电脑自动生成一段夕阳下的海滩视频。Wan2.2-I2V-A14B镜像让这个想象成为现实&#xff0c;它能将文字描述转化为高清…...

终极指南:5分钟掌握Piper鼠标地图组件与SVG渲染核心技术

终极指南&#xff1a;5分钟掌握Piper鼠标地图组件与SVG渲染核心技术 【免费下载链接】piper GTK application to configure gaming devices 项目地址: https://gitcode.com/gh_mirrors/pip/piper Piper是一款功能强大的GTK应用程序&#xff0c;专为配置游戏设备而设计。…...

Graph Node高级配置:环境变量与配置文件详解

Graph Node高级配置&#xff1a;环境变量与配置文件详解 【免费下载链接】graph-node Graph Node indexes data from blockchains such as Ethereum and serves it over GraphQL 项目地址: https://gitcode.com/gh_mirrors/gr/graph-node Graph Node 作为区块链数据索引…...

第三章、CLion+GCC+OpenOCD构建STM32标准库开发环境:从零到调试的完整实践

1. 环境准备与工具链安装 搭建STM32标准库开发环境的第一步&#xff0c;就是准备好所有必要的工具。这里我们需要三个核心组件&#xff1a;CLion作为集成开发环境、arm-none-eabi-gcc作为编译器、OpenOCD作为调试器。这三个工具的组合&#xff0c;可以让我们在Windows平台上获得…...