视频与AI,与进程交互(二) pytorch 极简训练自己的数据集并识别
目标学习任务
检测出已经分割出的图像的分类
2 使用pytorch
pytorch 非常简单就可以做到训练和加载
2.1 准备数据
如上图所示,用来训练的文件放在了train中,验证的文件放在val中,train.txt 和 val.txt 分别放文件名称和分类类别,然后我们在代码中写名字就行
里面我就为了做一个例子,放了两种文件,1 是 卡宴保时捷,2 是工程车,如下图所示
train.txt 如下图所示
val.txt 也是同样如此
3 show me the code
3.1 装载数据类
新增一个loaddata.py 文件
import torch
import random
from PIL import Image
class LoadData(torch.utils.data.Dataset):def __init__(self, root, datatxt, transform=None, target_transform=None):super(LoadData, self).__init__()file_txt = open(datatxt,'r')imgs = []for line in file_txt:line = line.rstrip()words = line.split('|')imgs.append((words[0], words[1]))self.imgs = imgsself.root = rootself.transform = transformself.target_transform = target_transformdef __getitem__(self, index):random.shuffle(self.imgs)name, label = self.imgs[index]img = Image.open(self.root + name).convert('RGB')if self.transform is not None:img = self.transform(img)label = int(label)return img, labeldef __len__(self):return len(self.imgs)
LoadData 类是从torch.util.data.Dataset上继承下来的,需要一个transform类输入,实际上就是转化大小
3.2 网络类
定义一个网络类,只有两个输出
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3)self.pool = nn.MaxPool2d((2, 2))self.pool1 = nn.MaxPool2d((2, 2))self.conv2 = nn.Conv2d(16, 32, 3)self.fc1 = nn.Linear(36*36*32, 120)self.fc2 = nn.Linear(120, 60)self.fc3 = nn.Linear(60, 2)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool1(F.relu(self.conv2(x)))x = x.view(-1, 36*36*32)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
3.3 主要流程
import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
from loaddata import LoadData
from modelnet import Netdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)classes = ['工程车','卡宴']
transform = transforms.Compose([transforms.Resize((152, 152)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_data=LoadData(root ='./data/train/',datatxt='./data/'+'train.txt',transform=transform)
test_data=LoadData(root ='./data/val/',datatxt='./data/'+'val.txt',transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=2)def imshow(img):img = img / 2 + 0.5 # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)for epoch in range(10):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 200 == 0:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 200))running_loss = 0.0print('Finished Training')PATH = './test.pth'
torch.save(net.state_dict(), PATH)net = Net()
net.load_state_dict(torch.load(PATH))correct = 0
total = 0
with torch.no_grad():for data in test_loader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))
如上图所示,epoch为5时精确度为80%,为10时精确度为100%,各位不要当真,这这是训练集里面的数据集做识别,并不是真的精确度。
3.4 识别代码
import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from modelnet import NetPATH = './test.pth'
transform = transforms.Compose([transforms.Resize((152, 152)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])net = Net()
net.load_state_dict(torch.load(PATH))img = Image.open("./data/val/102.jpg").convert('RGB')
img = transform(img)
with torch.no_grad():outputs = net(img)_, predicted = torch.max(outputs.data, 1)print("the 102 img lable is ",predicted)
如下图所示,102 为卡宴识别为1 正确
后记
后面我们准备是从视频中传递过来图像进行分类,同时使用我们的工具VT解码视频后进行内存共享来生成图像,而不是从磁盘加载。要用到我们的c++ 解码工具,和pytorch进行交互
以下是第一篇文章:视频与AI,与进程交互(一)
VT 工具准备开源,端午节节后开出来
相关文章:

视频与AI,与进程交互(二) pytorch 极简训练自己的数据集并识别
目标学习任务 检测出已经分割出的图像的分类 2 使用pytorch pytorch 非常简单就可以做到训练和加载 2.1 准备数据 如上图所示,用来训练的文件放在了train中,验证的文件放在val中,train.txt 和 val.txt 分别放文件名称和分类类别ÿ…...

LLM - 第2版 ChatGLM2-6B (General Language Model) 的工程配置
欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/131445696 ChatGLM2-6B 是开源中英双语对话模型 ChatGLM-6B 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优…...

从0开始,手写MySQL事务
说在前面:从0开始,手写MySQL的学习价值 尼恩曾经指导过的一个7年经验小伙,凭借精通Mysql, 搞定月薪40K。 从0开始,手写一个MySQL的学习价值在于: 可以深入地理解MySQL的内部机制和原理,Mysql可谓是面试的…...

React中useState的setState方法请求了好多次
1、问题描述 最近在写react的时候碰到了一个很奇怪的问题。 可以看到那个getXXX()的方法一直不断的被调用,网页一直请求,根本停不下来了。 2、产生原因 要弄明白这个原因,首先要先了解一下react生命周期。 react是组件式的编程,一…...

【MYSQL基础】基础命令介绍
基础命令 MYSQL注释方式 -- 单行注释/* 多行注释 哈哈哈哈哈 哈哈哈哈 */连接数据库 mysql -u root -p12345678退出数据库连接 使用exit;命令可以退出连接 查询MYSQL版本 mysql> select version(); ----------- | version() | ----------- | 8.0.27 | ----------- 1…...

多元回归预测 | Matlab基于灰狼算法优化深度置信网络(GWO-DBN)的数据回归预测,matlab代码回归预测,多变量输入模型
文章目录 效果一览文章概述部分源码参考资料效果一览 文章概述 多元回归预测 | Matlab基于灰狼算法优化深度置信网络(GWO-DBN)的数据回归预测,matlab代码回归预测,多变量输入模型,matlab代码回归预测,多变量输入模型,多变量输入模型 评价指标包括:MAE、RMSE和R2等,代码质…...

校园wifi网页认证登录入口
很多校园wifi网页认证登录入口是1.1.1.1 连上校园网在浏览器写上http://1.1.1.1就进入了校园网 使 用 说 明 一、帐户余额 < 0.00元时,帐号被禁用,需追加网费。 二、在计算中心机房上机的用户,登录时请选择新建帐号时给您指定的NT域&…...

[SpringBoot]Spring Security框架
目录 关于Spring Security框架 Spring Security框架的依赖项 Spring Security框架的典型特征 关于Spring Security的配置 关于默认的登录页 关于请求的授权访问(访问控制) 使用自定义的账号登录 使用数据库中的账号登录 关于密码编码器 使用BCry…...

Unity 之 抖音小游戏本地数据最新存储方法分享
Unity 之 抖音小游戏本地数据最新存储方法分享 一、抖音小游戏文件存储系统背景二、文件存储系统的使用方法2.1 初始化2.1 创建目录2.3 存储数据2.4 删除目录/文件2.5 其他相关操作 三,小结 抖音小游戏是一种基于抖音平台开发的小型游戏,与传统的 APP 不…...

逍遥自在学C语言 | 函数初级到高级解析
前言 函数是C语言中的基本构建块之一,它允许我们将代码组织成可重用、模块化的单元。 本文将逐步介绍C语言函数的基础概念、参数传递、返回值、递归以及内联函数和匿名函数。 一、人物简介 第一位闪亮登场,有请今后会一直教我们C语言的老师 —— 自在…...

Elastic 推出 Elastic AI 助手
作者:Mike Nichols Elastic 推出了 Elastic AI Assistant,这是一款由 ESRE 提供支持的开放式、生成式 AI 助手,旨在使网络安全民主化并支持各种技能水平的用户。 最近发布的 Elasticsearch Relevance Engine™ (ESRE™) 提供了用于创建高度相…...

【数据库】MySQL安装(最新图文保姆级别超详细版本介绍)
1.总共两部分(第二部可省略) 安装mysql体验mysql环境变量配置 1.1安装mysql 1.输入官网地址https://www.mysql.com/ 下载完成后,我们双击打开我们的下载文件 打开后的界面,如图所示 我们选择custom,点击nex…...
前端使用pdf-lib库实现pdf合并,window.open预览合并后的pdf
最近出差开了好多发票,写了一个pdf合并网站,用于把多张发票pdf合并成一张,方便打印 使用pdf-lib这个库实现的pdf合并功能,预览使用的是浏览器自身查看pdf功能 源码 网页地址 https://zqy233.github.io/PDF-merge/ <!DOCTYPE h…...
计算机网络相关知识点总结(二)
比特bit是计算机中数据量的最小单位,可简记为b。字节Byte也是计算机中数据量的单位,可简记为B,1B8bit。常用的数据量单位还有kB、MB、GB、TB等,其中k、M、G、T的数值分别为 2 10 2^{10} 210, 2 20 2^{20} 220, 2 30 2^{30} 230, 2 40 2^{40} 240。 K, M, G, T 分别对应以下…...
Redmine与Gitlab整合(实战版)
网上查了很多文章,总结一下。 安装过程略。可参考:(84条消息) Redmine与Gitlab功能集成_redmine gitlab_羽之大公公的博客-CSDN博客 配置集成的方法,参考: Redmine与GitLab集成 (ngui.cc) 修改ssh-key密码的方法,参…...

(3)深度学习学习笔记-简单线性模型
文章目录 一、线性模型二、实例1.pytorch求导功能2.简单线性模型(人工数据集) 来源 一、线性模型 一个简单模型:假设一个房子的价格由卧室、卫生间、居住面积决定,用x1,x2,x3表示。 那么房价y就可以认为yw…...

pytorch3d 安装报错 RuntimeError: Not compiled with GPU support pytorch3d
安装环境 NVIDIA GeForce RTX 3090 cuda 11.3 python 3.8.5 torch 1.11.0 torchvision 0.12.0 环境安装命令 conda install pytorch1.11.0 torchvision0.12.0 torchaudio0.11.0 cudatoolkit11.3 -c pytorch安装pytorch3d参考官网链接 https://github.com/facebookresearch/p…...

spring工程的启动流程?bean的生命周期?提供哪些扩展点?管理事务?解决循环依赖问题的?事务传播行为有哪些?
1.Spring工程的启动流程: Spring工程的启动流程主要包括以下几个步骤: 加载配置文件:Spring会读取配置文件(如XML配置文件或注解配置)来获取应用程序的配置信息。实例化并初始化IoC容器:Spring会创建并初…...
使用 Zabbix 监控 RocketMQ列举监控项和触发器
在使用 Zabbix 监控 RocketMQ 的过程中,以下是一些可能的监控项和触发器: 监控项 集群总体健康状况生产者和消费者的连接数量Broker 的状态消息的生产和消费速度队列深度(即队列中的消息数量)磁盘空间使用内存使用CPU使用网络流…...
uniApp:路由与页面跳转及传参
方式一:声明式导航 声明式导航,通过组件进行跳转。官方文档:详情 使用 navigator 组件进行页面跳转。 属性类型默认值说明urlString应用内的跳转链接,值为相对路径或绝对路径,如:“…/first/first”&#x…...
spring:实例工厂方法获取bean
spring处理使用静态工厂方法获取bean实例,也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下: 定义实例工厂类(Java代码),定义实例工厂(xml),定义调用实例工厂ÿ…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个生活电费的缴纳和查询小程序
一、项目初始化与配置 1. 创建项目 ohpm init harmony/utility-payment-app 2. 配置权限 // module.json5 {"requestPermissions": [{"name": "ohos.permission.INTERNET"},{"name": "ohos.permission.GET_NETWORK_INFO"…...
rnn判断string中第一次出现a的下标
# coding:utf8 import torch import torch.nn as nn import numpy as np import random import json""" 基于pytorch的网络编写 实现一个RNN网络完成多分类任务 判断字符 a 第一次出现在字符串中的位置 """class TorchModel(nn.Module):def __in…...
Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析
Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析 一、第一轮提问(基础概念问题) 1. 请解释Spring框架的核心容器是什么?它在Spring中起到什么作用? Spring框架的核心容器是IoC容器&#…...
JavaScript 数据类型详解
JavaScript 数据类型详解 JavaScript 数据类型分为 原始类型(Primitive) 和 对象类型(Object) 两大类,共 8 种(ES11): 一、原始类型(7种) 1. undefined 定…...

【网络安全】开源系统getshell漏洞挖掘
审计过程: 在入口文件admin/index.php中: 用户可以通过m,c,a等参数控制加载的文件和方法,在app/system/entrance.php中存在重点代码: 当M_TYPE system并且M_MODULE include时,会设置常量PATH_OWN_FILE为PATH_APP.M_T…...
python打卡第47天
昨天代码中注意力热图的部分顺移至今天 知识点回顾: 热力图 作业:对比不同卷积层热图可视化的结果 def visualize_attention_map(model, test_loader, device, class_names, num_samples3):"""可视化模型的注意力热力图,展示模…...

【版本控制】GitHub Desktop 入门教程与开源协作全流程解析
目录 0 引言1 GitHub Desktop 入门教程1.1 安装与基础配置1.2 核心功能使用指南仓库管理日常开发流程分支管理 2 GitHub 开源协作流程详解2.1 Fork & Pull Request 模型2.2 完整协作流程步骤步骤 1: Fork(创建个人副本)步骤 2: Clone(克隆…...

【Zephyr 系列 16】构建 BLE + LoRa 协同通信系统:网关转发与混合调度实战
🧠关键词:Zephyr、BLE、LoRa、混合通信、事件驱动、网关中继、低功耗调度 📌面向读者:希望将 BLE 和 LoRa 结合应用于资产追踪、环境监测、远程数据采集等场景的开发者 📊篇幅预计:5300+ 字 🧭 背景与需求 在许多 IoT 项目中,单一通信方式往往难以兼顾近场数据采集…...

多模态大语言模型arxiv论文略读(110)
CoVLA: Comprehensive Vision-Language-Action Dataset for Autonomous Driving ➡️ 论文标题:CoVLA: Comprehensive Vision-Language-Action Dataset for Autonomous Driving ➡️ 论文作者:Hidehisa Arai, Keita Miwa, Kento Sasaki, Yu Yamaguchi, …...