当前位置: 首页 > article >正文

基于卷积神经网络和Pyqt5的猫狗识别小程序

任务描述

猫狗分类任务(Dogs vs Cats)是Kaggle平台在2013年举办的一个经典计算机视觉竞赛。官方给出的Kaggle Dogs vs Cats 数据集中包括由12500张猫咪图片和12500张狗狗图片组成的训练集,12500张未标记照片组成的测试集。选手需要在规定时间内搭建模型,并在测试集中取得尽可能高的准确率。

数据集中的图片长这样👆,是不是还蛮贴近生活呢~

本文将设计一个简单的三层卷积神经网络,完成分类任务。基于Pyqt5设计交互界面,在测试集中检验分类结果。工作量不大非常适合新手作为入门项目~最终功能实现如下图所示。


 代码介绍

我觉得新手最有成就感的事情莫过于跑通一个功能了,我们抛开原理不谈,先完整运行程序。代码我是用pycharm写的,新建一个classfy_program项目,在classfy_program项目中创建两个py文件,分别用于训练和测试数据集。此外还创建data文件夹,用于存放训练和测试用的图片。

cat_dog_test.py代码:

import sys
import torch
import torch.nn as nn
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QPushButton, QVBoxLayout, QFileDialog
from PyQt5.QtGui import QPixmap
from PyQt5.QtCore import Qt
from torchvision import transforms
from PIL import Image# ---------------- 定义模型结构 ----------------
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(128 * 28 * 28, 512)self.fc2 = nn.Linear(512, 2)  # 猫 vs 狗def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = self.pool(torch.relu(self.conv3(x)))x = x.view(-1, 128 * 28 * 28)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# ---------------- 图像预处理 ----------------
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]),
])# 类别名称(根据你的训练数据文件夹顺序)
class_names = ['猫', '狗']# ---------------- 主窗口 ----------------
class CatDogClassifier(QWidget):def __init__(self):super().__init__()self.setWindowTitle("猫狗分类器")self.setGeometry(100, 100, 400, 500)# 加载模型self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.model = CNN().to(self.device)self.model.load_state_dict(torch.load("model_checkpoint.pth", map_location=self.device))self.model.eval()# 创建界面组件self.image_label = QLabel("请点击按钮加载图片", self)self.image_label.setAlignment(Qt.AlignCenter)self.image_label.setFixedSize(300, 300)self.result_label = QLabel("", self)self.result_label.setAlignment(Qt.AlignCenter)self.load_button = QPushButton("加载图片", self)self.load_button.clicked.connect(self.load_image)# 布局layout = QVBoxLayout()layout.addWidget(self.image_label)layout.addWidget(self.result_label)layout.addWidget(self.load_button)self.setLayout(layout)def load_image(self):file_path, _ = QFileDialog.getOpenFileName(self, "选择图片", "", "Image Files (*.png *.jpg *.jpeg)")if file_path:# 显示图片pixmap = QPixmap(file_path)pixmap = pixmap.scaled(300, 300, Qt.KeepAspectRatio)self.image_label.setPixmap(pixmap)# 进行预测image = Image.open(file_path).convert("RGB")input_tensor = transform(image).unsqueeze(0).to(self.device)with torch.no_grad():output = self.model(input_tensor)_, predicted = torch.max(output, 1)result = class_names[predicted.item()]self.result_label.setText(f"预测结果:{result}")# ---------------- 运行主程序 ----------------
if __name__ == "__main__":app = QApplication(sys.argv)window = CatDogClassifier()window.show()sys.exit(app.exec_())

cat_dog_train.py代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 数据预处理与加载
transform = transforms.Compose([transforms.Resize((224,224)),  # 调整图片尺寸为224x224transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # 适用于自定义模型的标准化
])# 设置训练集和验证集
train_dataset = datasets.ImageFolder(root='data/train', transform=transform)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 自定义卷积神经网络模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 第一层卷积self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)  # 输入通道为3(RGB),输出通道为32self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)# 池化层self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 全连接层self.fc1 = nn.Linear(128 * 28 * 28, 512)  # 调整为经过池化后的维度self.fc2 = nn.Linear(512, 2)  # 输出为2(猫 vs 狗)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))  # 第一层卷积 + ReLU + 最大池化x = self.pool(torch.relu(self.conv2(x)))  # 第二层卷积 + ReLU + 最大池化x = self.pool(torch.relu(self.conv3(x)))  # 第三层卷积 + ReLU + 最大池化x = x.view(-1, 128 * 28 * 28)  # 展平x = torch.relu(self.fc1(x))  # 全连接层 + ReLUx = self.fc2(x)  # 输出层return x# 实例化模型
model = CNN()# 使用GPU训练(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # 将模型移动到GPU(如果可用)# 损失函数与优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)# 训练模型
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0correct = 0total = 0# 训练阶段for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到GPU(如果可用)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 输出每轮训练损失与准确率print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")# 保存模型检查点torch.save(model.state_dict(), 'model_checkpoint.pth')print("Training complete!")

data数据集的下载链接在这里了:

Download Kaggle Cats and Dogs Dataset from Official Microsoft Download Center


希望看到这里的大家能够点一个小小的赞❤❤,后续会持续更新更多内容~~

相关文章:

基于卷积神经网络和Pyqt5的猫狗识别小程序

任务描述 猫狗分类任务(Dogs vs Cats)是Kaggle平台在2013年举办的一个经典计算机视觉竞赛。官方给出的Kaggle Dogs vs Cats 数据集中包括由12500张猫咪图片和12500张狗狗图片组成的训练集,12500张未标记照片组成的测试集。选手需要在规定时间…...

Hadoop 和 Spark 生态系统中的核心组件

以下是 Hadoop 和 Spark 生态系统的核心组件及其功能: Hadoop 生态核心组件 1. HDFS(Hadoop 分布式文件系统) - 命令/工具: hdfs 命令(如 hdfs dfs -put 等)。 - 作用:分布式存储海量数据&a…...

解锁健康养生新境界

在追求高品质生活的当下,健康养生早已超越 “治未病” 的传统认知,成为贯穿全生命周期的生活艺术。它如同精密的交响乐,需饮食、运动、心理与生活习惯多维度协奏,方能奏响生命的强音。 饮食养生讲究 “顺时、适性”。遵循二十四节…...

MQTT:轻量级物联网通信协议详解

引言 在物联网(IoT)迅速发展的今天,设备之间的高效通信变得至关重要。MQTT(Message Queuing Telemetry Transport)作为一种轻量级的发布/订阅消息传输协议,因其低带宽、低延迟和易于实现的特性&#xff0c…...

C 语言逻辑运算符:组合判断,构建更复杂的条件

各类资料学习下载合集 ​​https://pan.quark.cn/s/8c91ccb5a474​​ 在 C 语言编程中,我们已经学习了如何使用比较运算符(如 ​​==​​​, ​​<​​​, ​​>​​)来判断两个值之间的关系,从而得到“真”或“假”的结果。但很多时候,我们需要根据多个条件的组合…...

基于OpenCV的人脸识别:EigenFaces算法

文章目录 引言一、概述二、代码解析1. 准备工作2. 加载训练图像3. 设置标签4. 准备测试图像5. 创建和训练识别器6. 进行预测7. 显示结果 三、代码要点总结 引言 人脸识别是计算机视觉领域的一个重要应用&#xff0c;今天我将通过一个实际案例来展示如何使用OpenCV中的EigenFac…...

【深度学习新浪潮】智能追焦技术全解析:从算法到设备应用

一、智能追焦技术概述 智能追焦是基于人工智能和自动化技术的对焦功能,通过深度学习算法识别并持续跟踪移动物体(如人、动物、运动器械等),实时调整焦距以保持主体清晰,显著提升动态场景拍摄成功率。其核心优势包括: 精准性:AI 算法优化复杂运动轨迹追踪(如不规则移动…...

网络研讨会开发注册中, 5月15日特励达力科,“了解以太网”

在线研讨会主题 Understanding Ethernet - from basics to testing & optimization 了解以太网 - 从基础知识到测试和优化 注册链接# https://register.gotowebinar.com/register/2823468241337063262 时间 北京时间 2025 年 5 月 15 日 星期四 下午 3:30 - 4:30 适宜…...

在 Spring Boot 中实现动态线程池的全面指南

动态线程池是一种线程池管理方案&#xff0c;允许在运行时根据业务需求动态调整线程池参数&#xff08;如核心线程数、最大线程数、队列容量等&#xff09;&#xff0c;以优化资源利用率和系统性能。在 Spring Boot 中&#xff0c;动态线程池可以通过 Java 的 ThreadPoolExecut…...

STL?vector!!!

一、前言 之前我们借助手撕string加深了类和对象相关知识&#xff0c;今天我们将一起手撕一个vector&#xff0c;继续深化类和对象、动态内存管理、模板的相关知识 二、vector相关的前置知识 1、什么是vector&#xff1f; vector是一个STL库中提供的类模板&#xff0c;它是存储…...

从黔西游船侧翻事件看极端天气预警的科技防线——疾风气象大模型如何实现精准防御?

近日,贵州省黔西市一起载人游船侧翻事故令人痛心。调查显示,事发时当地突遇强风暴雨,水面突发巨浪导致船只失控。这一事件再次凸显:在极端天气频发的时代,传统“经验式防灾”已不足够,唯有依靠智能化的气象预警技术,才能筑牢安全底线。 极端天气预警的痛点:为什么传统方…...

LDO与DCDC总结

目录 1. 工作原理 2. 性能对比 3. 选型关键因素 4. 典型应用 总结 1. 工作原理 LDO LDO通过线性调节方式实现降压&#xff0c;输入电压需略高于输出电压&#xff08;压差通常为0.2-2V&#xff09;&#xff0c;利用内部PMOS管或PNP三极管调整压差以稳定输出电压。其结构简单…...

FastChat部署大模型

一、前提条件 1、系统环境&#xff08;使用的 autodl 算力平台&#xff09; 2、安装相关库 安装 modescope pip3 install -U modelscope # 或使用下方命令 # pip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple安装 fastchat git clone https://gi…...

智汇云舟亮相第二十七届北京科博会

5月8日&#xff0c;备受瞩目的第二十七届中国北京国际科技产业博览会&#xff08;以下简称&#xff1a;北京科博会&#xff09;在国家会议中心盛大开幕。作为我国科技领域的重要盛会&#xff0c;北京科博会汇聚了众多前沿科技成果与创新力量&#xff0c;为全球科技产业交流搭建…...

Redis最新入门教程

文章目录 Redis最新入门教程1.安装Redis2.连接Redis3.Redis环境变量配置4.入门Redis4.1 Redis的数据结构4.2 Redis的Key4.3 Redis-String4.4 Redis-Hash4.5 Redis-List4.6 Redis-Set4.7 Redis-Zset 5.在Java中使用Redis6.缓存雪崩、击穿、穿透6.1 缓存雪崩6.2 缓冲击穿6.3 缓冲…...

北斗三号手持终端设备功能与应用

北斗三号卫星系统是我国自主建设、独立运行的全球卫星导航系统。通过多颗不同轨道卫星组成的&#xff0c;这些卫星持续向地球发射携带精确时间和位置信息的信号。地面上的北斗手持终端接收到至少四颗卫星信号后&#xff0c;利用信号传播时间差&#xff0c;通过三角函数等算法&a…...

提升编程效率的利器:Zed高性能多人协作代码编辑器

在当今这个快节奏的开发环境中&#xff0c;一个高效、灵活的代码编辑器无疑对开发者们起着至关重要的支持作用。Zed&#xff0c;作为来自知名编辑器Atom和语法解析器Tree-sitter的创造者的心血之作&#xff0c;正是这样一款高性能支持多人合作的编辑神器。本文将带领大家深入探…...

【数据机构】2. 线性表之“顺序表”

- 第 96 篇 - Date: 2025 - 05 - 09 Author: 郑龙浩/仟墨 【数据结构 2】 文章目录 数据结构 - 2 -线性表之“顺序表”1 基本概念2 顺序表(一般为数组)① 基本介绍② 分类 (静态与动态)③ 动态顺序表的实现**test.c文件:****SeqList.h文件:****SeqList.c文件:** 数据结构 - 2 …...

Java如何获取电脑分辨率?

以下是一个 Java 程序示例&#xff0c;用于获取电脑的主屏幕分辨率&#xff1a; import java.awt.*; public class ScreenResolutionExample { public static void main(String[] args) { // 获取默认的屏幕设备 GraphicsDevice device GraphicsEnvironm…...

opencv中的图像特征提取

图像的特征&#xff0c;一般是指图像所表达出的该图像的特有属性&#xff0c;其实就是事物的图像特征&#xff0c;由于图像获得的多样性&#xff08;拍摄器材、角度等&#xff09;&#xff0c;事物的图像特征有时并不特别突出或与无关物体混杂在一起&#xff0c;因此图像的特征…...

dropout层

从你提供的图片来看&#xff0c;里面讨论了 Dropout 层&#xff0c;让我为你解释一下它的工作原理和作用。 Dropout 层是什么&#xff1f; Dropout 是一种常用的正则化技术&#xff0c;用于避免神经网络的 过拟合&#xff08;overfitting&#xff09;。过拟合是指模型在训练数…...

极狐GitLab 容器镜像仓库功能介绍

极狐GitLab 是 GitLab 在中国的发行版&#xff0c;关于中文参考文档和资料有&#xff1a; 极狐GitLab 中文文档极狐GitLab 中文论坛极狐GitLab 官网 极狐GitLab 容器镜像库 (BASIC ALL) 您可以使用集成的容器镜像库&#xff0c;来存储每个极狐GitLab 项目的容器镜像。 要为您…...

【JVM-GC调优】

一、预备知识 掌握GC相关的VM参数&#xff0c;会基本的空间调整掌握相关工具明白一点&#xff1a;调优跟应用、环境有关&#xff0c;没有放之四海而皆准的法则 二、调优领域 内存锁竞争cpu占用io 三、确定目标 【低延迟】&#xff1a;CMS、G1&#xff08;低延迟、高吞吐&a…...

shopping mall(document)

shopping mall&#xff08;document&#xff09; 商城的原型&#xff0c;学习&#xff0c;优化&#xff0c;如何比别人做的更好&#xff0c;更加符合大众的习惯 抄别人会陷入一个怪圈&#xff0c;就是已经习惯了&#xff0c;也懒了&#xff0c;也不带思考了。 许多产品会迫于…...

qiankun微前端任意位置子应用

qiankun微前端任意位置子应用 主项目1、安装qiankun2、引入注册3、路由创建4、路由守卫 二、子项目1、安装sh-winter/vite-plugin-qiankun2、main.js配置3、vite.config.js配置 三、问题解决 主项目 1、安装qiankun npm i qiankun -S2、引入注册 创建存放子应用页面 //whpv…...

【LeetCode Solutions】LeetCode 176 ~ 180 题解

CONTENTS LeetCode 176. 第二高的薪水&#xff08;SQL 中等&#xff09;LeetCode 177. 第 N 高的薪水&#xff08;SQL 中等&#xff09;LeetCode 178. 分数排名&#xff08;SQL 中等&#xff09;LeetCode 179. 最大数&#xff08;中等&#xff09;LeetCode 180. 连续出现的数字…...

前端面试每日三题 - Day 29

这是我为准备前端/全栈开发工程师面试整理的第29天每日三题练习&#xff1a; ✅ 题目1&#xff1a;Web Components技术全景解析 核心三要素 Custom Elements&#xff08;自定义元素&#xff09; class MyButton extends HTMLElement {constructor() {super();this.attachShado…...

第十五章,SSL VPN

前言 IPSec 和 SSL 对比 IPSec远程接入场景---client提前安装软件&#xff0c;存在一定的兼容性问题 IPSec协议只能够对感兴趣的流量进行加密保护&#xff0c;意味着接入用户需要不停的调整策略&#xff0c;来适应IPSec隧道 IPSec协议对用户访问权限颗粒度划分的不够详细&…...

spring5.x讲解介绍

Spring 5.x 是 Spring Framework 的重要版本升级&#xff0c;全面拥抱现代 Java 技术栈&#xff0c;其核心改进涵盖响应式编程、Java 8支持、性能优化及开发模式创新。以下从特性、架构和应用场景三个维度详细解析&#xff1a; 一、核心特性与架构改进 Java 8 全面支持 Spring …...

荣耀A8互动娱乐组件部署实录(第3部分:控制端结构与房间通信协议)

作者&#xff1a;曾在 WebSocket 超时里泡了七天七夜的苦命人 一、控制端总体架构概述 荣耀A8控制端主要承担的是“运营支点”功能&#xff0c;也就是开发与运营之间的桥梁。它既不直接参与玩家行为&#xff0c;又控制着玩家的行为逻辑和游戏规则触发机制。控制端的主要职责包…...