当前位置: 首页 > news >正文

pytorch实现水果2分类(蓝莓,苹果)

1.数据集的路径,结构

dataset.py

目的:

        输入:没有输入,路径是写死了的。

        输出:返回的是一个对象,里面有self.data。self.data是一个列表,里面是(图片路径.jpg,标签)

        -data[item]返回的是(img_tensor , one-hot编码)。one-hot编码是[0,1]或者[1,0]

import glob
import os.pathimport cv2
import torch
from torch.utils.data import Dataset
from torchvision import transformsclass DtataAndLabel(Dataset):def __init__(self,path='fruits',is_train=True):self.tran=transforms.Compose([transforms.ToTensor(),transforms.Resize(size=(88,88))])is_train='train' if True else 'test'self.data=[]path=os.path.join(path,is_train)print('path=',path)print(os.path.join(path, '*', '*'))img_paths=glob.glob(os.path.join(path,'*','*'))for img_path in img_paths:label=0 if img_path.split('\\')[-2]=='blueberry' else 1self.data.append((img_path,label))def __getitem__(self, idx):#每一张图片返回一个img_tensor,one_hotimg_path,label =self.data[idx]img=cv2.imread(img_path)# img_gray=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)img_tensor=self.tran(img)img_tensor=img_tensor/255img_tensor=torch.flatten(img_tensor)one_hot=torch.zeros(2)one_hot[label]=1return img_tensor,one_hotdef __len__(self):return len(self.data)if __name__ == '__main__':# 测试data=DtataAndLabel()print(data[1][0].shape)print(data[1][1])

net.py

目的:将输入维度(k(k是加载进去的图片数),88,88,3)三通道的宽高是88,88,通过网络变化为(k,2)。

import torch.nn
import torch.nn as nnclass Net(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(88*88*3, 800),nn.ReLU(),nn.Linear(800, 500),nn.ReLU(),nn.Linear(500, 800),nn.ReLU(),nn.Linear(800, 200),nn.ReLU(),nn.Linear(200, 2),)self.softmax=nn.Softmax(dim=1)def forward(self,x):x=self.model(x)x=self.softmax(x)return x
if __name__ == '__main__':net=Net()#测试一下x=torch.randn(1,100*100)out=net(x)print(out.shape)

test_train.py

目的:将图像丢进模型,然后训练出最优模型

步骤:

       1.定义初始化

                -定义拿到data对象

                -定义加载器分批加载,这里可以变换维度

                -定义初始化网络

                -定义损失函数,这里采用了均方差函数

                -定义优化器

        2.实现训练

                -将每一批数据丢给网络,此时维度发生了变化,产生了升维

                -使用优化器        

                        ---自动梯度清0

                        ---自动求导更新参数

                -计算损失值和准确度

        ·~自己建一个文件夹

import torch.optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdmfrom net import Net
from dataset import DtataAndLabel
import torch.nn as nn
class TrainAndTest():def __init__(self):self.writer = SummaryWriter("logs")self.train_data=DtataAndLabel(is_train=True)self.test_data=DtataAndLabel(is_train=False)#使用加载器分批加载self.train_loader=DataLoader(self.train_data,batch_size=10,shuffle=True)self.test_loader=DataLoader(self.test_data,batch_size=10,shuffle=True)#初始化网络#损失函数#优化器net=Net()self.net=netself.loss=nn.MSELoss()self.opt=torch.optim.Adam(net.parameters(),lr=0.001)self.min_loss=100.0self.weight_path='weight/best.pt'def train(self,epoch):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)self.opt.zero_grad()loss.backward()self.opt.step()sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.train_loader)avg_acc = sum_acc / len(self.train_loader)print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')self.writer.add_scalars("loss", {"train_avg_loss": avg_loss}, epoch)self.writer.add_scalars("acc", {"train_avg_acc": avg_acc}, epoch)def test(self,epoch):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm(self.test_loader, desc="test...", total=len(self.test_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.test_loader)avg_acc = sum_acc / len(self.test_loader)print(f'test:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')self.writer.add_scalars("loss", {"test_avg_loss": avg_loss}, epoch)self.writer.add_scalars("acc", {"test_avg_acc": avg_acc}, epoch)if avg_loss<self.min_loss:self.min_loss=min(self.min_loss,avg_loss)torch.save(self.net.state_dict(), self.weight_path)def run(self):for epo in range(100):self.train(epo)self.test(epo)if __name__ == '__main__':trainer=TrainAndTest()trainer.run()

精度的计算:

                比如通过网络出现的维度是(1,2),其数值是[[0.9 , 0.1]](0.9与0.1表示预测的两个类别的概率)。我们通过maxarg取到其中最大的索引0,与之前真实的标签0或者1做比较。从而可以得出结果

 

相关文章:

pytorch实现水果2分类(蓝莓,苹果)

1.数据集的路径&#xff0c;结构 dataset.py 目的&#xff1a; 输入&#xff1a;没有输入&#xff0c;路径是写死了的。 输出&#xff1a;返回的是一个对象&#xff0c;里面有self.data。self.data是一个列表&#xff0c;里面是&#xff08;图片路径.jpg&#xff0c;标签&…...

Redis实践经验

优雅的Key结构 Key实践约定&#xff1a; 遵循基本格式&#xff1a;[业务名称]:[数据名]:id例&#xff1a;login:user:10长度步超过44字节&#xff08;版本不同&#xff0c;上限不同&#xff09;不包含特殊字符 优点&#xff1a; 可读性强避免key冲突方便管理节省内存&#x…...

分类题解清单

目录 简介MySQL题一、聚合函数二、排序和分组三、高级查询和连接四、子查询五、高级字符串函数 / 正则表达式 / 子句 算法题一、双指针二、滑动窗口三、模拟四、贪心五、矩阵六、排序七、链表八、设计九、前缀和十、哈希表十一、字符串十二、二叉树十三、二分查找十四、回溯十五…...

QUdpSocket 的bind函数详解

QUdpSocket 是 Qt 框架中用于处理 UDP 网络通信的类。bind 函数是此类中的一个重要方法&#xff0c;它用于将 QUdpSocket 对象绑定到一个特定的端口上&#xff0c;以便在该端口上接收 UDP 数据包。 函数原型 在 Qt 中&#xff0c;bind 函数的原型通常如下所示&#xff1a; b…...

[spring] Spring MVC - security(下)

[spring] Spring MVC - security&#xff08;下&#xff09; callback 一下&#xff0c;当前项目结构如下&#xff1a; 这里实现的功能是连接数据库&#xff0c;大范围和 [spring] rest api security 重合 数据库连接 - 明文密码 第一部分使用明文密码 设置数据库 主要就是…...

数据库数据恢复—SQL Server数据库由于存放空间不足报错的数据恢复案例

SQL Server数据库数据恢复环境&#xff1a; 某品牌服务器存储中有两组raid5磁盘阵列。操作系统层面跑着SQL Server数据库&#xff0c;SQL Server数据库存放在D盘分区中。 SQL Server数据库故障&#xff1a; 存放SQL Server数据库的D盘分区容量不足&#xff0c;管理员在E盘中生…...

spring security的demo

参考&#xff1a; https://juejin.cn/post/6844903502003568647 Spring Security 5.7.0弃用 WebSecurityConfigurerAdapter-CSDN博客 创建 Spring Security 配置类 WebSecurityConfigurerAdapter已被弃用 package com.cq.sc.security.config;import org.springframework.c…...

无需构建工具,快速上手Vue2 + ElementUI

无需构建工具&#xff0c;快速上手Vue2 ElementUI 在前端开发的世界中&#xff0c;Vue.js以其轻量级和易用性赢得了开发者的青睐。而Element UI&#xff0c;作为一个基于Vue 2.0的桌面端组件库&#xff0c;提供了丰富的界面组件&#xff0c;使得构建美观且功能丰富的应用变得…...

通信协议_Modbus协议简介

概念介绍 Modbus协议&#xff1a;一种串行通信协议&#xff0c;是Modicon公司&#xff08;现在的施耐德电气Schneider Electric&#xff09;于1979年为使用可编程逻辑控制器&#xff08;PLC&#xff09;通信而发表。Modbus已经成为工业领域通信协议的业界标准&#xff08;De f…...

LabVIEW优化氢燃料电池

太阳能和风能的发展引入了许多新的能量储存方法。随着科技的发展&#xff0c;能源储存和需求平衡的方法也需要不断创新。智慧城市倡导放弃石化化合物&#xff0c;采用环境友好的发电和储能技术。氢气系统和储存链在绿色能源倡议中起着关键作用。然而&#xff0c;氢气密度低&…...

SpringCloudGateway

作用 统一管理&#xff0c;易于监控安全&#xff0c;限流&#xff1a;在网关层就过滤掉非法信息nginx外部网关&#xff0c;gateway内网nginx可以使用Lua或Kong来增强 概念 id:名称随意uri: 被代理的服务地址。id和uri必填&#xff0c;谓词和过滤器非必填谓词&#xff1a;可以…...

Wireshark 对 https 请求抓包并展示为明文

文章目录 1、目标2、环境准备3、Wireshark 基本使用4、操作步骤4.1、彻底关闭 Chrome 进程4.2、配置 SSLKEYLOGFILE [核心步骤]4.3、把文件路径配置到 Wireshark 指定位置4.4、在浏览器发起请求4.5、抓包配置4.6、过滤4.6.1、过滤域名 http.host contains "baidu.com4.6.2…...

如何在Ubuntu环境下使用加速器配置Docker环境

一、安装并打开加速器 这个要根据每个加速器的情况来安装并打开&#xff0c;一般是会开放一个代理端口&#xff0c;比如1087 二、安装Docker https://docs.docker.com/engine/install/debian/#install-using-the-convenience-script 三、配置Docker使用加速器 3.1 修改配置…...

2.5 C#视觉程序开发实例1----CamManager实现模拟相机采集图片

2.5 C#视觉程序开发实例1----CamManager实现模拟相机采集图片 1 目标效果视频 CamManager 2 CamManager读取本地文件时序 3 BD_Vision_Utility添加代码 3.0 导入链接库 BD_OperatorSets.dllSystem.Windows.Forms.dllOpencvSharp 3.1 导入VisionParam中创建的文件Util_FileO…...

算法简介:什么是算法?——定义、历史与应用详解

引言 在现代计算机科学中&#xff0c;算法是一个核心概念。无论是编程还是数据分析&#xff0c;算法都扮演着至关重要的角色。在这篇博客中&#xff0c;我们将深入探讨算法的定义、历史背景以及它在计算机科学中的地位和实际应用。 什么是算法&#xff1f; 算法是解决特定问题…...

xss攻击

一、xss攻击简介 1、OWASP TOP 10之一,XSS被称为跨站脚本攻击(Cross-site-scripting)2、主要基于java script(JS)完成恶意攻击行为。JS可以非常灵活的操作html、css和浏览器,这使得XSS攻击的“想象”空间特别大。3、XSS通过将精心构造代码(JS)代码注入到网页中&#xff0c;并由…...

Android 性能优化之布局优化

文章目录 Android 性能优化之布局优化绘制原理双缓冲机制布局加载原理检测耗时常规方式AOP方式获取控件加载耗时 布局优化AsyncLayoutInflater方案Compose方案减少布局层级和复杂度避免过度绘制 Android 性能优化之布局优化 绘制原理 CPU&#xff1a;负责执行应用层的measure…...

TCP 握手数据流

这张图详细描述了 TCP 握手过程中&#xff0c;从客户端发送 SYN 包到服务器最终建立连接的整个数据流转过程&#xff0c;包括网卡、内核、进程中的各个环节。下面对每个步骤进行详细解释&#xff1a; 客户端到服务器的初始连接请求 客户端发送 SYN 包&#xff1a; 客户端发起…...

MDA协议

MDA协议通常指消息摘要算法&#xff08;Message Digest Algorithm&#xff09;&#xff0c;在计算机安全和密码学中被广泛用于数据完整性验证和认证。以下是对MDA协议的详细介绍&#xff1a; 1. 概述 MDA协议是一类哈希函数&#xff0c;用于生成固定长度的消息摘要或哈希值。…...

always块敏感列表的相关报错,

在综合的时候&#xff0c;报错如下 Synthesis synth_1 [Synth 8-91] ambiguous clock in event control ["E:/FPGA/FPGA_project/handwrite_fft/handwrite_fft.srcs/sources_1/new/reg_s2p.v":140] 猜测报错原因&#xff08;暂时没有时间寻找原因&#xff0c;后续在…...

高效解决Windows 11 LTSC系统Microsoft Store缺失的完整实战指南

高效解决Windows 11 LTSC系统Microsoft Store缺失的完整实战指南 【免费下载链接】LTSC-Add-MicrosoftStore Add Windows Store to Windows 11 24H2 LTSC 项目地址: https://gitcode.com/gh_mirrors/ltscad/LTSC-Add-MicrosoftStore Windows 11 24H2 LTSC版本以其卓越的…...

基于GAN的AI图像水印移除工具VeoWatermarkRemover实战指南

1. 项目概述&#xff1a;一个开源图像水印移除工具 最近在整理一些老照片和网上下载的素材时&#xff0c;经常被图片上那些碍眼的水印、Logo或者时间戳困扰。手动用PS处理&#xff0c;费时费力&#xff0c;而且对批量操作极不友好。直到我发现了GitHub上一个名为“VeoWatermar…...

try-catch到底有没有性能开销

有一种说法是”try-catch 有性能开销&#xff0c;关键路径上不要用”。另一种说法是”try-catch 不抛异常的话没有开销”。这两种说法都不全对&#xff0c;开销在哪里要看具体用法。try-catch 本身不贵&#xff0c;异常对象才贵JVM 里&#xff0c;try-catch 的实现方式是在字节…...

H5GG完整指南:如何用JavaScript和HTML5轻松修改iOS游戏内存

H5GG完整指南&#xff1a;如何用JavaScript和HTML5轻松修改iOS游戏内存 【免费下载链接】H5GG an iOS Mod Engine with JavaScript APIs & Html5 UI 项目地址: https://gitcode.com/gh_mirrors/h5/H5GG 你是否曾经想过修改iOS游戏中的数值&#xff0c;却因为复杂的越…...

深度解析:三合一技术方案破解Cursor AI编辑器限制的终极指南

深度解析&#xff1a;三合一技术方案破解Cursor AI编辑器限制的终极指南 【免费下载链接】cursor-free-vip [Support 0.45]&#xff08;Multi Language 多语言&#xff09;自动注册 Cursor Ai &#xff0c;自动重置机器ID &#xff0c; 免费升级使用Pro 功能: Youve reached yo…...

从‘黑盒子’到清晰电路:用替代定理‘破译’未知网络N的VCR(图解+方程双解法)

从‘黑盒子’到清晰电路&#xff1a;用替代定理‘破译’未知网络N的VCR&#xff08;图解方程双解法&#xff09; 在电子工程实践中&#xff0c;工程师们常常会遇到一种令人头疼的"黑盒子"——那些内部结构不明、数据手册不全的电路模块。面对这样的未知网络&#xff…...

告别Office安装烦恼:3分钟搞定微软办公套件自动部署

告别Office安装烦恼&#xff1a;3分钟搞定微软办公套件自动部署 【免费下载链接】LKY_OfficeTools 一键自动化 下载、安装、激活 Office 的利器。 项目地址: https://gitcode.com/GitHub_Trending/lk/LKY_OfficeTools 还在为繁琐的Office安装流程而头疼吗&#xff1f;一…...

保姆级避坑指南:从模之屋PMX到Unity,搞定Blender导出FBX的纹理丢失问题

保姆级避坑指南&#xff1a;从模之屋PMX到Unity&#xff0c;搞定Blender导出FBX的纹理丢失问题 如果你是一位二次元风格游戏开发者或MMD模型爱好者&#xff0c;那么从模之屋下载PMX模型后&#xff0c;在Blender中处理并导出为FBX格式&#xff0c;最后导入Unity的过程中&#xf…...

Armv9 SME2架构下BFloat16计算优化与机器学习加速

1. SME2指令集与BFloat16计算优化解析在Armv9架构的SME2扩展中&#xff0c;BFloat16&#xff08;简称BF16&#xff09;支持成为机器学习加速的关键特性。这种16位浮点格式通过截断IEEE 754单精度浮点的尾数位&#xff08;从23位减至7位&#xff09;&#xff0c;同时保留完整的8…...

因果推理第四层盲区:为什么关联≠因果

因果推理第四层盲区&#xff1a;为什么关联≠因果 副标题: 从Pearl因果阶梯到知识库因果链&#xff0c;AI如何跨越观测vs建模的鸿沟痛点&#xff1a;为什么你的AI只能"描述"不能"规划"&#xff1f; 你有没有遇到过这样的情况&#xff1a; AI能告诉你"…...