视频与AI,与进程交互(二) pytorch 极简训练自己的数据集并识别
目标学习任务
检测出已经分割出的图像的分类
2 使用pytorch
pytorch 非常简单就可以做到训练和加载
2.1 准备数据
如上图所示,用来训练的文件放在了train中,验证的文件放在val中,train.txt 和 val.txt 分别放文件名称和分类类别,然后我们在代码中写名字就行
里面我就为了做一个例子,放了两种文件,1 是 卡宴保时捷,2 是工程车,如下图所示
train.txt 如下图所示
val.txt 也是同样如此
3 show me the code
3.1 装载数据类
新增一个loaddata.py 文件
import torch
import random
from PIL import Image
class LoadData(torch.utils.data.Dataset):def __init__(self, root, datatxt, transform=None, target_transform=None):super(LoadData, self).__init__()file_txt = open(datatxt,'r')imgs = []for line in file_txt:line = line.rstrip()words = line.split('|')imgs.append((words[0], words[1]))self.imgs = imgsself.root = rootself.transform = transformself.target_transform = target_transformdef __getitem__(self, index):random.shuffle(self.imgs)name, label = self.imgs[index]img = Image.open(self.root + name).convert('RGB')if self.transform is not None:img = self.transform(img)label = int(label)return img, labeldef __len__(self):return len(self.imgs)
LoadData 类是从torch.util.data.Dataset上继承下来的,需要一个transform类输入,实际上就是转化大小
3.2 网络类
定义一个网络类,只有两个输出
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3)self.pool = nn.MaxPool2d((2, 2))self.pool1 = nn.MaxPool2d((2, 2))self.conv2 = nn.Conv2d(16, 32, 3)self.fc1 = nn.Linear(36*36*32, 120)self.fc2 = nn.Linear(120, 60)self.fc3 = nn.Linear(60, 2)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool1(F.relu(self.conv2(x)))x = x.view(-1, 36*36*32)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
3.3 主要流程
import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
from loaddata import LoadData
from modelnet import Netdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)classes = ['工程车','卡宴']
transform = transforms.Compose([transforms.Resize((152, 152)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_data=LoadData(root ='./data/train/',datatxt='./data/'+'train.txt',transform=transform)
test_data=LoadData(root ='./data/val/',datatxt='./data/'+'val.txt',transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=2)def imshow(img):img = img / 2 + 0.5 # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)for epoch in range(10):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 200 == 0:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 200))running_loss = 0.0print('Finished Training')PATH = './test.pth'
torch.save(net.state_dict(), PATH)net = Net()
net.load_state_dict(torch.load(PATH))correct = 0
total = 0
with torch.no_grad():for data in test_loader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))
如上图所示,epoch为5时精确度为80%,为10时精确度为100%,各位不要当真,这这是训练集里面的数据集做识别,并不是真的精确度。
3.4 识别代码
import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from modelnet import NetPATH = './test.pth'
transform = transforms.Compose([transforms.Resize((152, 152)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])net = Net()
net.load_state_dict(torch.load(PATH))img = Image.open("./data/val/102.jpg").convert('RGB')
img = transform(img)
with torch.no_grad():outputs = net(img)_, predicted = torch.max(outputs.data, 1)print("the 102 img lable is ",predicted)
如下图所示,102 为卡宴识别为1 正确
后记
后面我们准备是从视频中传递过来图像进行分类,同时使用我们的工具VT解码视频后进行内存共享来生成图像,而不是从磁盘加载。要用到我们的c++ 解码工具,和pytorch进行交互
以下是第一篇文章:视频与AI,与进程交互(一)
VT 工具准备开源,端午节节后开出来
相关文章:

视频与AI,与进程交互(二) pytorch 极简训练自己的数据集并识别
目标学习任务 检测出已经分割出的图像的分类 2 使用pytorch pytorch 非常简单就可以做到训练和加载 2.1 准备数据 如上图所示,用来训练的文件放在了train中,验证的文件放在val中,train.txt 和 val.txt 分别放文件名称和分类类别ÿ…...

LLM - 第2版 ChatGLM2-6B (General Language Model) 的工程配置
欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/131445696 ChatGLM2-6B 是开源中英双语对话模型 ChatGLM-6B 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优…...

从0开始,手写MySQL事务
说在前面:从0开始,手写MySQL的学习价值 尼恩曾经指导过的一个7年经验小伙,凭借精通Mysql, 搞定月薪40K。 从0开始,手写一个MySQL的学习价值在于: 可以深入地理解MySQL的内部机制和原理,Mysql可谓是面试的…...

React中useState的setState方法请求了好多次
1、问题描述 最近在写react的时候碰到了一个很奇怪的问题。 可以看到那个getXXX()的方法一直不断的被调用,网页一直请求,根本停不下来了。 2、产生原因 要弄明白这个原因,首先要先了解一下react生命周期。 react是组件式的编程,一…...

【MYSQL基础】基础命令介绍
基础命令 MYSQL注释方式 -- 单行注释/* 多行注释 哈哈哈哈哈 哈哈哈哈 */连接数据库 mysql -u root -p12345678退出数据库连接 使用exit;命令可以退出连接 查询MYSQL版本 mysql> select version(); ----------- | version() | ----------- | 8.0.27 | ----------- 1…...

多元回归预测 | Matlab基于灰狼算法优化深度置信网络(GWO-DBN)的数据回归预测,matlab代码回归预测,多变量输入模型
文章目录 效果一览文章概述部分源码参考资料效果一览 文章概述 多元回归预测 | Matlab基于灰狼算法优化深度置信网络(GWO-DBN)的数据回归预测,matlab代码回归预测,多变量输入模型,matlab代码回归预测,多变量输入模型,多变量输入模型 评价指标包括:MAE、RMSE和R2等,代码质…...

校园wifi网页认证登录入口
很多校园wifi网页认证登录入口是1.1.1.1 连上校园网在浏览器写上http://1.1.1.1就进入了校园网 使 用 说 明 一、帐户余额 < 0.00元时,帐号被禁用,需追加网费。 二、在计算中心机房上机的用户,登录时请选择新建帐号时给您指定的NT域&…...

[SpringBoot]Spring Security框架
目录 关于Spring Security框架 Spring Security框架的依赖项 Spring Security框架的典型特征 关于Spring Security的配置 关于默认的登录页 关于请求的授权访问(访问控制) 使用自定义的账号登录 使用数据库中的账号登录 关于密码编码器 使用BCry…...

Unity 之 抖音小游戏本地数据最新存储方法分享
Unity 之 抖音小游戏本地数据最新存储方法分享 一、抖音小游戏文件存储系统背景二、文件存储系统的使用方法2.1 初始化2.1 创建目录2.3 存储数据2.4 删除目录/文件2.5 其他相关操作 三,小结 抖音小游戏是一种基于抖音平台开发的小型游戏,与传统的 APP 不…...

逍遥自在学C语言 | 函数初级到高级解析
前言 函数是C语言中的基本构建块之一,它允许我们将代码组织成可重用、模块化的单元。 本文将逐步介绍C语言函数的基础概念、参数传递、返回值、递归以及内联函数和匿名函数。 一、人物简介 第一位闪亮登场,有请今后会一直教我们C语言的老师 —— 自在…...

Elastic 推出 Elastic AI 助手
作者:Mike Nichols Elastic 推出了 Elastic AI Assistant,这是一款由 ESRE 提供支持的开放式、生成式 AI 助手,旨在使网络安全民主化并支持各种技能水平的用户。 最近发布的 Elasticsearch Relevance Engine™ (ESRE™) 提供了用于创建高度相…...

【数据库】MySQL安装(最新图文保姆级别超详细版本介绍)
1.总共两部分(第二部可省略) 安装mysql体验mysql环境变量配置 1.1安装mysql 1.输入官网地址https://www.mysql.com/ 下载完成后,我们双击打开我们的下载文件 打开后的界面,如图所示 我们选择custom,点击nex…...

前端使用pdf-lib库实现pdf合并,window.open预览合并后的pdf
最近出差开了好多发票,写了一个pdf合并网站,用于把多张发票pdf合并成一张,方便打印 使用pdf-lib这个库实现的pdf合并功能,预览使用的是浏览器自身查看pdf功能 源码 网页地址 https://zqy233.github.io/PDF-merge/ <!DOCTYPE h…...

计算机网络相关知识点总结(二)
比特bit是计算机中数据量的最小单位,可简记为b。字节Byte也是计算机中数据量的单位,可简记为B,1B8bit。常用的数据量单位还有kB、MB、GB、TB等,其中k、M、G、T的数值分别为 2 10 2^{10} 210, 2 20 2^{20} 220, 2 30 2^{30} 230, 2 40 2^{40} 240。 K, M, G, T 分别对应以下…...

Redmine与Gitlab整合(实战版)
网上查了很多文章,总结一下。 安装过程略。可参考:(84条消息) Redmine与Gitlab功能集成_redmine gitlab_羽之大公公的博客-CSDN博客 配置集成的方法,参考: Redmine与GitLab集成 (ngui.cc) 修改ssh-key密码的方法,参…...

(3)深度学习学习笔记-简单线性模型
文章目录 一、线性模型二、实例1.pytorch求导功能2.简单线性模型(人工数据集) 来源 一、线性模型 一个简单模型:假设一个房子的价格由卧室、卫生间、居住面积决定,用x1,x2,x3表示。 那么房价y就可以认为yw…...

pytorch3d 安装报错 RuntimeError: Not compiled with GPU support pytorch3d
安装环境 NVIDIA GeForce RTX 3090 cuda 11.3 python 3.8.5 torch 1.11.0 torchvision 0.12.0 环境安装命令 conda install pytorch1.11.0 torchvision0.12.0 torchaudio0.11.0 cudatoolkit11.3 -c pytorch安装pytorch3d参考官网链接 https://github.com/facebookresearch/p…...

spring工程的启动流程?bean的生命周期?提供哪些扩展点?管理事务?解决循环依赖问题的?事务传播行为有哪些?
1.Spring工程的启动流程: Spring工程的启动流程主要包括以下几个步骤: 加载配置文件:Spring会读取配置文件(如XML配置文件或注解配置)来获取应用程序的配置信息。实例化并初始化IoC容器:Spring会创建并初…...

使用 Zabbix 监控 RocketMQ列举监控项和触发器
在使用 Zabbix 监控 RocketMQ 的过程中,以下是一些可能的监控项和触发器: 监控项 集群总体健康状况生产者和消费者的连接数量Broker 的状态消息的生产和消费速度队列深度(即队列中的消息数量)磁盘空间使用内存使用CPU使用网络流…...

uniApp:路由与页面跳转及传参
方式一:声明式导航 声明式导航,通过组件进行跳转。官方文档:详情 使用 navigator 组件进行页面跳转。 属性类型默认值说明urlString应用内的跳转链接,值为相对路径或绝对路径,如:“…/first/first”&#x…...

Java中操作文件(二)
目录 一、什么是数据流 二、InputStream概述 2.1、方法 2.2、说明 三、FileInputStream概述 3.1、构造方法 3.2、利用Scanner进行字符串读取,简化操作 四、OutputStream概述 4.1、方法 4.2、PrinterWriter简化写操作 五、小程序练习 示例1 示例…...

springboot+vue在线考试系统(java项目源码+文档)
风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的在线考试系统。项目源码以及部署相关请联系风歌,文末附上联系信息 。 💕💕作者:风歌&a…...

样式方案:在 Vite 中接入现代化的 CSS 工程化方案
上一小节,我们使用 Vite 初始化了一个 Web 项目,迈出了使用 Vite 的第一步。但在实际工作中,仅用 Vite 官方的脚手架项目是不够的,往往还需要考虑诸多的工程化因素,借助 Vite 本身的配置以及业界的各种生态,…...

C#获取根目录实现方法汇总
以下是C#获取不同类型项目根目录的实现方法汇总,以及在 .NET Core 中获取项目根目录的方法: 控制台应用程序 string rootPath Environment.CurrentDirectory; string rootPath AppDomain.CurrentDomain.BaseDirectory; string rootPath Path.GetFul…...

vue获取当前坐标并通过天地图逆转码为省市区
因为需求需要获取用户当前的地理位置用于分析 通过原生的navigator.geolocation.getCurrentPosition获取经纬度 这个方法在谷歌浏览器会失效(原因未知),目前ie浏览器是可以获取的 getCurrentPosition() {if (navigator.geolocation) {var o…...

【MySQL】事务及其隔离性/隔离级别
目录 一、事务的概念 1、事务的四种特性 2、事务的作用 3、存储引擎对事务的支持 4、事务的提交方式 二、事务的启动、回滚与提交 1、准备工作:调整MySQL的默认隔离级别为最低/创建测试表 2、事务的启动、回滚与提交 3、启动事务后未commit,但是…...

计算机由于找不到d3dx9_35.dll,无法启动软件游戏的三个修复方法
在打开游戏的时候,计算机提示由于找不到d3dx9_35.dll,无法正常启动运行。这个是为什么呢?d3dx9_35.dll是DirectX 9.0里面的一个动态连结库文件,它包含了Direct3D、DirectPlay几个组件的二进制文件,为软件提供了多媒体图…...

第三章 模型篇:模型与模型的搭建
写在前面的话 这部分只解释代码,不对线性层(全连接层),卷积层等layer的原理进行解释。 尽量写的比较全了,但是自身水平有限,不太确定是否有遗漏重要的部分。 教程参考: https://pytorch.org/tutorials/ https://githu…...

深度学习一些简单概念的整理笔记
大概看了一点动手学深度学习,简单整理一些概念。 一些问题 测试结果 Precision-Recall曲线定性分析模型精度average precision(AP) 平均精度 Precision :检索出来的条目中有多大比例是我们需要的。 一些概念 损失函数(loss function&…...

Vue3中引入Element-plus
安装 npm install element-plus --save完整引入 打包后体积很大,适合学习,不适合生产。 此方法对于 vite 和 cli 脚手架创建的vue3均适用 // main.ts import { createApp } from vue //引入element-plus import ElementPlus from element-plus //引入…...