Pytorch使用VGG16模型进行预测猫狗二分类
目录
1. VGG16
1.1 VGG16 介绍
1.1.1 VGG16 网络的整体结构
1.2 Pytorch使用VGG16进行猫狗二分类实战
1.2.1 数据集准备
1.2.2 构建VGG网络
1.2.3 训练和评估模型
1. VGG16
1.1 VGG16 介绍
深度学习已经在计算机视觉领域取得了巨大的成功,特别是在图像分类任务中。VGG16是深度学习中经典的卷积神经网络(Convolutional Neural Network,CNN)之一,由牛津大学的Karen Simonyan和Andrew Zisserman在2014年提出。VGG16网络以其深度和简洁性而闻名,是图像分类中的重要里程碑。
VGG16是Visual Geometry Group的缩写,它的名字来源于提出该网络的实验室。VGG16的设计目标是通过增加网络深度来提高图像分类的性能,并展示了深度对于图像分类任务的重要性。VGG16的主要特点是将多个小尺寸的卷积核堆叠在一起,从而形成更深的网络。
1.1.1 VGG16 网络的整体结构
VGG16网络由多个卷积层和全连接层组成。它的整体结构相对简单,所有的卷积层都采用小尺寸的卷积核(通常为3x3),步幅为1,填充为1。每个卷积层后面都会跟着一个ReLU激活函数来引入非线性。
VGG16网络主要由三个部分组成:
-
输入层:接受图像输入,通常为224x224大小的彩色图像(RGB)。
-
卷积层:VGG16包含13个卷积层,其中包括五个卷积块。
-
全连接层:在卷积层后面是3个全连接层,用于最终的分类。
VGG16网络结构如下图:

1、一张原始图片被resize到(224,224,3)。
2、conv1两次[3,3]卷积网络,输出的特征层为64,输出为(224,224,64),再2X2最大池化,输出net为(112,112,64)。
3、conv2两次[3,3]卷积网络,输出的特征层为128,输出net为(112,112,128),再2X2最大池化,输出net为(56,56,128)。
4、conv3三次[3,3]卷积网络,输出的特征层为256,输出net为(56,56,256),再2X2最大池化,输出net为(28,28,256)。
5、conv4三次[3,3]卷积网络,输出的特征层为512,输出net为(28,28,512),再2X2最大池化,输出net为(14,14,512)。
6、conv5三次[3,3]卷积网络,输出的特征层为512,输出net为(14,14,512),再2X2最大池化,输出net为(7,7,512)。
7、利用卷积的方式模拟全连接层,效果等同,输出net为(1,1,4096)。共进行两次。
8、利用卷积的方式模拟全连接层,效果等同,输出net为(1,1,1000)。
最后输出的就是每个类的预测。
1.2 Pytorch使用VGG16进行猫狗二分类实战
在这一部分,我们将使用PyTorch来实现VGG16网络,用于猫狗预测的二分类任务。我们将对VGG16的网络结构进行适当的修改,以适应我们的任务。
1.2.1 数据集准备
首先,我们需要准备用于猫狗二分类的数据集。数据集可以从Kaggle上下载,其中包含了大量的猫和狗的图片。在下载数据集后,我们需要将数据集划分为训练集和测试集。训练集文件夹命名为train,其中建立两个文件夹分别为cat和dog,每个文件夹里存放相应类别的图片。测试集命名为test,同理。
import torch
import torchvision
import torchvision.transforms as transforms# 定义数据转换
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载数据集
train_dataset = ImageFolder("train", transform=transform)
test_dataset = ImageFolder("test", transform=transform)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
1.2.2 构建VGG网络
import torch.nn as nnclass VGG16(nn.Module):def __init__(self):super(VGG16, self).__init__()self.features = nn.Sequential(# Block 1nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# Block 2nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# Block 3nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# Block 4nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# Block 5nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 2) # 输出层,二分类任务)def forward(self, x):x = self.features(x)x = torch.flatten(x, 1) # 展开特征图x = self.classifier(x)return x# 初始化VGG16模型
vgg16 = VGG16()
在上述代码中,我们定义了一个VGG16类,其中self.features部分包含了5个卷积块,self.classifier部分包含了3个全连接层。
1.2.3 训练和评估模型
import torch.optim as optim# 定义超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 10model = VGG16()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item()}")
torch.save(model,'model/vgg16.pth')
# 测试模型
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)print(outputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Accuracy on test images: {(correct / total) * 100}%")
在训练模型时,我们使用交叉熵损失函数(CrossEntropyLoss)作为分类任务的损失函数,并采用随机梯度下降(SGD)作为优化器。同时,我们将模型移动到GPU(如果可用)来加速训练过程。
相关文章:
Pytorch使用VGG16模型进行预测猫狗二分类
目录 1. VGG16 1.1 VGG16 介绍 1.1.1 VGG16 网络的整体结构 1.2 Pytorch使用VGG16进行猫狗二分类实战 1.2.1 数据集准备 1.2.2 构建VGG网络 1.2.3 训练和评估模型 1. VGG16 1.1 VGG16 介绍 深度学习已经在计算机视觉领域取得了巨大的成功,特别是在图像分类任…...
安装nvm使用nvm管理node切换npm镜像后使用vue ui管理构建项目成功
如果安装nvm前已经单独安装过node.js的请先自行卸载原有node和环境变量里面的配置; 亲测成功,有哪些问题可以在评论区发消息或者私聊我 1、安装nvm的步骤如下 下载nvm安装包 在nvm的GitHub仓库,如下是国内镜像仓库: 点击这里跳…...
在线LaTeX公式编辑器编辑公式
在线LaTeX公式编辑器编辑公式 在编辑LaTex文档时候,需要输入公式,可以使用在线LaTeX公式编辑器编辑公式,其链接为: 在线LaTeX公式编辑器,https://www.latexlive.com/home 图1 在线LaTeX公式编辑器界面 图2 在线LaTeX公式编辑器…...
【C、C++】学习0
C、C学习路线 C语法:变量、条件、循环、字符串、数组、函数、结构体等指针、内存管理推荐书籍:《C Primer Plus》、《C和指针》、《C专家编程》 CC语言基础C的面向对象(封装、继承与多态)特性泛型模板STL等等推荐书籍(…...
python GUI nicegui初识一(登录界面创建)
最近尝试了python的nicegui库,虽然可能也有一些不足,但个人感觉对于想要开发不过对ui设计感到很麻烦的人来说是很友好的了,毕竟nicegui可以利用TailwindCSS和Quasar进行ui开发,并且也支持定制自己的css样式。 这里记录一下自己利…...
【单片机】51单片机串口的收发实验,串口程序
这段代码是使用C语言编写的用于8051单片机的串口通信程序。它实现了以下功能: 引入必要的头文件,包括reg52.h、intrins.h、string.h、stdio.h和stdlib.h。 定义了常量FSOC和BAUD,分别表示系统时钟频率和波特率。 定义了一个发送数据的函数…...
【bug】记录一次使用Swiper插件时loop属性和slidersPerView属性冲突问题
简言 最近在vue3使用swiper时,突然发现loop属性和slides-per-view属性同时存在启用时,loop生效,下一步只能生效一次的bug,上一步却是好的。非常滴奇怪。 解决过程 分析属性是否使用错误。 loop是循环模式,布尔型。 …...
云原生应用里的服务发现
服务定义: 服务定义是声明给定服务如何被消费者/客户端使用的方式。在建立服务之间的同步通信通道之前,它会与消费者共享。 同步通信中的服务定义: 微服务可以将其服务定义发布到服务注册表(或由微服务所有者手动发布)…...
【零基础学Rust | 基础系列 | 基础语法】变量,数据类型,运算符,控制流
文章目录 简介:一,变量1,变量的定义2,变量的可变性3,变量的隐藏 二、数据类型1,标量类型2,复合类型 三,运算符1,算术运算符2,比较运算符3,逻辑运算…...
运输层---概述
目录 运输层主要内容一.概述和传输层服务1.1 概述1.2 传输服务和协议1.3 传输层 vs. 网络层1.4 Internet传输层协议 二. 多路复用与多路分解(解复用)2.1 概述2.2 无连接与面向连接的多路分解(解复用)2.3面向连接的多路复用*2.4 We…...
高速公路巡检无人机,为何成为公路巡检的主流工具
随着无人机技术的不断发展,无人机越来越多地应用于各个领域。其中,在高速公路领域,高速公路巡检无人机已成为公路巡检的得力助手。高速公路巡检无人机之所以能够成为公路巡检中的主流工具,主要是因为其具备以下三大特性。 一、高速…...
仓库管理系统有哪些功能,如何对仓库进行有效管理
阅读本文,您可以了解:1、仓库管理系统有哪些功能;2、如何对仓库进行有效管理。 仓库是制造业的开端,原材料的领料开始。企业的仓库管理是涉及企业生产、企业资金流和企业的经营风险的关键环节。在众多的工业企业、制造型企业、贸…...
Java 比Automic更高效的累加器
1、 java常见的原子类 类 Atomiclnteger、AtomicIntegerArray、AtomicIntegerFieldUpdater、AtomicLongArray、 AtomicLongFieldUpdater、AtomicReference、AtomicReferenceArray 和 AtomicReference- FieldUpdater 常见的原子类使用方法 使用 AtomicReference 来创建一个原…...
antDv table组件滚动截图方法的实现
在开发中经常遇到table内容过多产生滚动的场景,正常情况下不产生滚动进行截图就很好实现,一旦产生滚动就会变得有点棘手。 下面分两种场景阐述解决的方法过程 场景一:右侧不固定列的情况 场景二:右侧固定列的情况 场景一 打开…...
JavaSE【抽象类和接口】(抽象类、接口、实现多个接口、接口的继承)
一、抽象类 在 Java 中,一个类如果被 abstract 修饰称为抽象类,抽象类中被 abstract 修饰的方法称为抽象方法,抽象方法不用 给出具体的实现体。 1.语法 // 抽象类:被 abstract 修饰的类 public abstract class Shape { …...
微信小程序如何跳转H5页面
1、登录微信公众后台,进入【开发->开发管理->业务域名】,点击修改。 2、首先请下载校验文件,并将文件放置在域名根目录下。 我是把文件放在nginx主机的data目录下,然后通过增加nginx.config配置,重启nginx后可…...
C++(20):bit_cast
C++20之前如果想对不同的指针之间做类型转换需要通过reinterpret_cast,对于整数与指针之前的转换也需要通过reinterpret_cast: C++:reinterpret_cast_c++ reparant_cast_风静如云的博客-CSDN博客 但是reinterpret_cast的缺点是不同的编译环境下,无法包装转型的安全一致。 …...
STM32 低功耗-停止模式
STM32 停止模式 文章目录 STM32 停止模式第1章 低功耗模式简介第2章 停止模式简介2.1 进入停止模式2.1 退出停止模式 第3章 停止模式程序部分总结 第1章 低功耗模式简介 在 STM32 的正常工作中,具有四种工作模式:运行、睡眠、停止以及待机模式。 在系统…...
Hutool中 常用的工具类和方法
文章目录 日期时间工具类 DateUtil日期时间对象-DateTime类型转换工具类 Convert字符串工具类 StrUtil数字处理工具类 NumberUtilJavaBean的工具类 BeanUtil集合操作的工具类 CollUtilMap操作工具类 MapUtil数组工具-ArrayUtil唯一ID工具-IdUtilIO工具类-IoUtil加密解密工具类 …...
K8s(健康检查+滚动更新+优雅停机+弹性伸缩+Prometheus监控+配置分离)
前言 快速配置请直接跳转至汇总配置 K8s SpringBoot实现零宕机发布:健康检查滚动更新优雅停机弹性伸缩Prometheus监控配置分离(镜像复用) 配置 健康检查 健康检查类型:就绪探针(readiness) 存活探针&am…...
PPT|230页| 制造集团企业供应链端到端的数字化解决方案:从需求到结算的全链路业务闭环构建
制造业采购供应链管理是企业运营的核心环节,供应链协同管理在供应链上下游企业之间建立紧密的合作关系,通过信息共享、资源整合、业务协同等方式,实现供应链的全面管理和优化,提高供应链的效率和透明度,降低供应链的成…...
AtCoder 第409场初级竞赛 A~E题解
A Conflict 【题目链接】 原题链接:A - Conflict 【考点】 枚举 【题目大意】 找到是否有两人都想要的物品。 【解析】 遍历两端字符串,只有在同时为 o 时输出 Yes 并结束程序,否则输出 No。 【难度】 GESP三级 【代码参考】 #i…...
大数据零基础学习day1之环境准备和大数据初步理解
学习大数据会使用到多台Linux服务器。 一、环境准备 1、VMware 基于VMware构建Linux虚拟机 是大数据从业者或者IT从业者的必备技能之一也是成本低廉的方案 所以VMware虚拟机方案是必须要学习的。 (1)设置网关 打开VMware虚拟机,点击编辑…...
【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)
骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术,它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton):由层级结构的骨头组成,类似于人体骨骼蒙皮 (Mesh Skinning):将模型网格顶点绑定到骨骼上,使骨骼移动…...
selenium学习实战【Python爬虫】
selenium学习实战【Python爬虫】 文章目录 selenium学习实战【Python爬虫】一、声明二、学习目标三、安装依赖3.1 安装selenium库3.2 安装浏览器驱动3.2.1 查看Edge版本3.2.2 驱动安装 四、代码讲解4.1 配置浏览器4.2 加载更多4.3 寻找内容4.4 完整代码 五、报告文件爬取5.1 提…...
JS手写代码篇----使用Promise封装AJAX请求
15、使用Promise封装AJAX请求 promise就有reject和resolve了,就不必写成功和失败的回调函数了 const BASEURL ./手写ajax/test.jsonfunction promiseAjax() {return new Promise((resolve, reject) > {const xhr new XMLHttpRequest();xhr.open("get&quo…...
如何更改默认 Crontab 编辑器 ?
在 Linux 领域中,crontab 是您可能经常遇到的一个术语。这个实用程序在类 unix 操作系统上可用,用于调度在预定义时间和间隔自动执行的任务。这对管理员和高级用户非常有益,允许他们自动执行各种系统任务。 编辑 Crontab 文件通常使用文本编…...
C# 表达式和运算符(求值顺序)
求值顺序 表达式可以由许多嵌套的子表达式构成。子表达式的求值顺序可以使表达式的最终值发生 变化。 例如,已知表达式3*52,依照子表达式的求值顺序,有两种可能的结果,如图9-3所示。 如果乘法先执行,结果是17。如果5…...
uniapp 小程序 学习(一)
利用Hbuilder 创建项目 运行到内置浏览器看效果 下载微信小程序 安装到Hbuilder 下载地址 :开发者工具默认安装 设置服务端口号 在Hbuilder中设置微信小程序 配置 找到运行设置,将微信开发者工具放入到Hbuilder中, 打开后出现 如下 bug 解…...
【Linux】Linux安装并配置RabbitMQ
目录 1. 安装 Erlang 2. 安装 RabbitMQ 2.1.添加 RabbitMQ 仓库 2.2.安装 RabbitMQ 3.配置 3.1.启动和管理服务 4. 访问管理界面 5.安装问题 6.修改密码 7.修改端口 7.1.找到文件 7.2.修改文件 1. 安装 Erlang 由于 RabbitMQ 是用 Erlang 编写的,需要先安…...
