当前位置: 首页 > news >正文

Pytorch使用VGG16模型进行预测猫狗二分类

目录

1. VGG16

1.1 VGG16 介绍

1.1.1 VGG16 网络的整体结构

 1.2 Pytorch使用VGG16进行猫狗二分类实战

1.2.1 数据集准备

1.2.2 构建VGG网络

1.2.3 训练和评估模型


 

1. VGG16

1.1 VGG16 介绍

深度学习已经在计算机视觉领域取得了巨大的成功,特别是在图像分类任务中。VGG16是深度学习中经典的卷积神经网络(Convolutional Neural Network,CNN)之一,由牛津大学的Karen Simonyan和Andrew Zisserman在2014年提出。VGG16网络以其深度和简洁性而闻名,是图像分类中的重要里程碑。

VGG16是Visual Geometry Group的缩写,它的名字来源于提出该网络的实验室。VGG16的设计目标是通过增加网络深度来提高图像分类的性能,并展示了深度对于图像分类任务的重要性。VGG16的主要特点是将多个小尺寸的卷积核堆叠在一起,从而形成更深的网络。

1.1.1 VGG16 网络的整体结构

VGG16网络由多个卷积层和全连接层组成。它的整体结构相对简单,所有的卷积层都采用小尺寸的卷积核(通常为3x3),步幅为1,填充为1。每个卷积层后面都会跟着一个ReLU激活函数来引入非线性。

VGG16网络主要由三个部分组成:

  1. 输入层:接受图像输入,通常为224x224大小的彩色图像(RGB)。

  2. 卷积层:VGG16包含13个卷积层,其中包括五个卷积块。

  3. 全连接层:在卷积层后面是3个全连接层,用于最终的分类。

VGG16网络结构如下图:

watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDc5MTk2NA==,size_16,color_FFFFFF,t_70#pic_center

1、一张原始图片被resize到(224,224,3)。
2、conv1两次[3,3]卷积网络,输出的特征层为64,输出为(224,224,64),再2X2最大池化,输出net为(112,112,64)。
3、conv2两次[3,3]卷积网络,输出的特征层为128,输出net为(112,112,128),再2X2最大池化,输出net为(56,56,128)。
4、conv3三次[3,3]卷积网络,输出的特征层为256,输出net为(56,56,256),再2X2最大池化,输出net为(28,28,256)。
5、conv4三次[3,3]卷积网络,输出的特征层为512,输出net为(28,28,512),再2X2最大池化,输出net为(14,14,512)。
6、conv5三次[3,3]卷积网络,输出的特征层为512,输出net为(14,14,512),再2X2最大池化,输出net为(7,7,512)。
7、利用卷积的方式模拟全连接层,效果等同,输出net为(1,1,4096)。共进行两次。
8、利用卷积的方式模拟全连接层,效果等同,输出net为(1,1,1000)。
最后输出的就是每个类的预测。

 1.2 Pytorch使用VGG16进行猫狗二分类实战

在这一部分,我们将使用PyTorch来实现VGG16网络,用于猫狗预测的二分类任务。我们将对VGG16的网络结构进行适当的修改,以适应我们的任务。

1.2.1 数据集准备

首先,我们需要准备用于猫狗二分类的数据集。数据集可以从Kaggle上下载,其中包含了大量的猫和狗的图片。在下载数据集后,我们需要将数据集划分为训练集和测试集。训练集文件夹命名为train,其中建立两个文件夹分别为cat和dog,每个文件夹里存放相应类别的图片。测试集命名为test,同理。

import torch
import torchvision
import torchvision.transforms as transforms# 定义数据转换
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载数据集
train_dataset = ImageFolder("train", transform=transform)
test_dataset = ImageFolder("test", transform=transform)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

1.2.2 构建VGG网络

import torch.nn as nnclass VGG16(nn.Module):def __init__(self):super(VGG16, self).__init__()self.features = nn.Sequential(# Block 1nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# Block 2nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# Block 3nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# Block 4nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# Block 5nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 2)  # 输出层,二分类任务)def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)  # 展开特征图x = self.classifier(x)return x# 初始化VGG16模型
vgg16 = VGG16()

在上述代码中,我们定义了一个VGG16类,其中self.features部分包含了5个卷积块,self.classifier部分包含了3个全连接层。

1.2.3 训练和评估模型

import torch.optim as optim# 定义超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 10model = VGG16()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item()}")
torch.save(model,'model/vgg16.pth')
# 测试模型
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)print(outputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Accuracy on test images: {(correct / total) * 100}%")

在训练模型时,我们使用交叉熵损失函数(CrossEntropyLoss)作为分类任务的损失函数,并采用随机梯度下降(SGD)作为优化器。同时,我们将模型移动到GPU(如果可用)来加速训练过程。

 

 


 

 

相关文章:

Pytorch使用VGG16模型进行预测猫狗二分类

目录 1. VGG16 1.1 VGG16 介绍 1.1.1 VGG16 网络的整体结构 1.2 Pytorch使用VGG16进行猫狗二分类实战 1.2.1 数据集准备 1.2.2 构建VGG网络 1.2.3 训练和评估模型 1. VGG16 1.1 VGG16 介绍 深度学习已经在计算机视觉领域取得了巨大的成功,特别是在图像分类任…...

安装nvm使用nvm管理node切换npm镜像后使用vue ui管理构建项目成功

如果安装nvm前已经单独安装过node.js的请先自行卸载原有node和环境变量里面的配置; 亲测成功,有哪些问题可以在评论区发消息或者私聊我 1、安装nvm的步骤如下 下载nvm安装包 在nvm的GitHub仓库,如下是国内镜像仓库: 点击这里跳…...

在线LaTeX公式编辑器编辑公式

在线LaTeX公式编辑器编辑公式 在编辑LaTex文档时候,需要输入公式,可以使用在线LaTeX公式编辑器编辑公式,其链接为: 在线LaTeX公式编辑器,https://www.latexlive.com/home 图1 在线LaTeX公式编辑器界面 图2 在线LaTeX公式编辑器…...

【C、C++】学习0

C、C学习路线 C语法:变量、条件、循环、字符串、数组、函数、结构体等指针、内存管理推荐书籍:《C Primer Plus》、《C和指针》、《C专家编程》 CC语言基础C的面向对象(封装、继承与多态)特性泛型模板STL等等推荐书籍(…...

python GUI nicegui初识一(登录界面创建)

最近尝试了python的nicegui库,虽然可能也有一些不足,但个人感觉对于想要开发不过对ui设计感到很麻烦的人来说是很友好的了,毕竟nicegui可以利用TailwindCSS和Quasar进行ui开发,并且也支持定制自己的css样式。 这里记录一下自己利…...

【单片机】51单片机串口的收发实验,串口程序

这段代码是使用C语言编写的用于8051单片机的串口通信程序。它实现了以下功能: 引入必要的头文件,包括reg52.h、intrins.h、string.h、stdio.h和stdlib.h。 定义了常量FSOC和BAUD,分别表示系统时钟频率和波特率。 定义了一个发送数据的函数…...

【bug】记录一次使用Swiper插件时loop属性和slidersPerView属性冲突问题

简言 最近在vue3使用swiper时,突然发现loop属性和slides-per-view属性同时存在启用时,loop生效,下一步只能生效一次的bug,上一步却是好的。非常滴奇怪。 解决过程 分析属性是否使用错误。 loop是循环模式,布尔型。 …...

云原生应用里的服务发现

服务定义: 服务定义是声明给定服务如何被消费者/客户端使用的方式。在建立服务之间的同步通信通道之前,它会与消费者共享。 同步通信中的服务定义: 微服务可以将其服务定义发布到服务注册表(或由微服务所有者手动发布)…...

【零基础学Rust | 基础系列 | 基础语法】变量,数据类型,运算符,控制流

文章目录 简介:一,变量1,变量的定义2,变量的可变性3,变量的隐藏 二、数据类型1,标量类型2,复合类型 三,运算符1,算术运算符2,比较运算符3,逻辑运算…...

运输层---概述

目录 运输层主要内容一.概述和传输层服务1.1 概述1.2 传输服务和协议1.3 传输层 vs. 网络层1.4 Internet传输层协议 二. 多路复用与多路分解(解复用)2.1 概述2.2 无连接与面向连接的多路分解(解复用)2.3面向连接的多路复用*2.4 We…...

高速公路巡检无人机,为何成为公路巡检的主流工具

随着无人机技术的不断发展,无人机越来越多地应用于各个领域。其中,在高速公路领域,高速公路巡检无人机已成为公路巡检的得力助手。高速公路巡检无人机之所以能够成为公路巡检中的主流工具,主要是因为其具备以下三大特性。 一、高速…...

仓库管理系统有哪些功能,如何对仓库进行有效管理

阅读本文,您可以了解:1、仓库管理系统有哪些功能;2、如何对仓库进行有效管理。 仓库是制造业的开端,原材料的领料开始。企业的仓库管理是涉及企业生产、企业资金流和企业的经营风险的关键环节。在众多的工业企业、制造型企业、贸…...

Java 比Automic更高效的累加器

1、 java常见的原子类 类 Atomiclnteger、AtomicIntegerArray、AtomicIntegerFieldUpdater、AtomicLongArray、 AtomicLongFieldUpdater、AtomicReference、AtomicReferenceArray 和 AtomicReference- FieldUpdater 常见的原子类使用方法 使用 AtomicReference 来创建一个原…...

antDv table组件滚动截图方法的实现

在开发中经常遇到table内容过多产生滚动的场景,正常情况下不产生滚动进行截图就很好实现,一旦产生滚动就会变得有点棘手。 下面分两种场景阐述解决的方法过程 场景一:右侧不固定列的情况 场景二:右侧固定列的情况 场景一 打开…...

JavaSE【抽象类和接口】(抽象类、接口、实现多个接口、接口的继承)

一、抽象类 在 Java 中,一个类如果被 abstract 修饰称为抽象类,抽象类中被 abstract 修饰的方法称为抽象方法,抽象方法不用 给出具体的实现体。 1.语法 // 抽象类:被 abstract 修饰的类 public abstract class Shape { …...

微信小程序如何跳转H5页面

1、登录微信公众后台,进入【开发->开发管理->业务域名】,点击修改。 2、首先请下载校验文件,并将文件放置在域名根目录下。 我是把文件放在nginx主机的data目录下,然后通过增加nginx.config配置,重启nginx后可…...

C++(20):bit_cast

C++20之前如果想对不同的指针之间做类型转换需要通过reinterpret_cast,对于整数与指针之前的转换也需要通过reinterpret_cast: C++:reinterpret_cast_c++ reparant_cast_风静如云的博客-CSDN博客 但是reinterpret_cast的缺点是不同的编译环境下,无法包装转型的安全一致。 …...

STM32 低功耗-停止模式

STM32 停止模式 文章目录 STM32 停止模式第1章 低功耗模式简介第2章 停止模式简介2.1 进入停止模式2.1 退出停止模式 第3章 停止模式程序部分总结 第1章 低功耗模式简介 在 STM32 的正常工作中,具有四种工作模式:运行、睡眠、停止以及待机模式。 在系统…...

Hutool中 常用的工具类和方法

文章目录 日期时间工具类 DateUtil日期时间对象-DateTime类型转换工具类 Convert字符串工具类 StrUtil数字处理工具类 NumberUtilJavaBean的工具类 BeanUtil集合操作的工具类 CollUtilMap操作工具类 MapUtil数组工具-ArrayUtil唯一ID工具-IdUtilIO工具类-IoUtil加密解密工具类 …...

K8s(健康检查+滚动更新+优雅停机+弹性伸缩+Prometheus监控+配置分离)

前言 快速配置请直接跳转至汇总配置 K8s SpringBoot实现零宕机发布:健康检查滚动更新优雅停机弹性伸缩Prometheus监控配置分离(镜像复用) 配置 健康检查 健康检查类型:就绪探针(readiness) 存活探针&am…...

挑战杯推荐项目

“人工智能”创意赛 - 智能艺术创作助手:借助大模型技术,开发能根据用户输入的主题、风格等要求,生成绘画、音乐、文学作品等多种形式艺术创作灵感或初稿的应用,帮助艺术家和创意爱好者激发创意、提高创作效率。 ​ - 个性化梦境…...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

Mobile ALOHA全身模仿学习

一、题目 Mobile ALOHA:通过低成本全身远程操作学习双手移动操作 传统模仿学习(Imitation Learning)缺点:聚焦与桌面操作,缺乏通用任务所需的移动性和灵活性 本论文优点:(1)在ALOHA…...

python报错No module named ‘tensorflow.keras‘

是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...

Xen Server服务器释放磁盘空间

disk.sh #!/bin/bashcd /run/sr-mount/e54f0646-ae11-0457-b64f-eba4673b824c # 全部虚拟机物理磁盘文件存储 a$(ls -l | awk {print $NF} | cut -d. -f1) # 使用中的虚拟机物理磁盘文件 b$(xe vm-disk-list --multiple | grep uuid | awk {print $NF})printf "%s\n"…...

计算机基础知识解析:从应用到架构的全面拆解

目录 前言 1、 计算机的应用领域:无处不在的数字助手 2、 计算机的进化史:从算盘到量子计算 3、计算机的分类:不止 “台式机和笔记本” 4、计算机的组件:硬件与软件的协同 4.1 硬件:五大核心部件 4.2 软件&#…...

【无标题】湖北理元理律师事务所:债务优化中的生活保障与法律平衡之道

文/法律实务观察组 在债务重组领域,专业机构的核心价值不仅在于减轻债务数字,更在于帮助债务人在履行义务的同时维持基本生活尊严。湖北理元理律师事务所的服务实践表明,合法债务优化需同步实现三重平衡: 法律刚性(债…...

算法打卡第18天

从中序与后序遍历序列构造二叉树 (力扣106题) 给定两个整数数组 inorder 和 postorder ,其中 inorder 是二叉树的中序遍历, postorder 是同一棵树的后序遍历,请你构造并返回这颗 二叉树 。 示例 1: 输入:inorder [9,3,15,20,7…...

Python学习(8) ----- Python的类与对象

Python 中的类(Class)与对象(Object)是面向对象编程(OOP)的核心。我们可以通过“类是模板,对象是实例”来理解它们的关系。 🧱 一句话理解: 类就像“图纸”,对…...

echarts使用graphic强行给图增加一个边框(边框根据自己的图形大小设置)- 适用于无法使用dom的样式

pdf-lib https://blog.csdn.net/Shi_haoliu/article/details/148157624?spm1001.2014.3001.5501 为了完成在pdf中导出echarts图,如果边框加在dom上面,pdf-lib导出svg的时候并不会导出边框,所以只能在echarts图上面加边框 grid的边框是在图里…...