当前位置: 首页 > news >正文

Pytorch使用VGG16模型进行预测猫狗二分类

目录

1. VGG16

1.1 VGG16 介绍

1.1.1 VGG16 网络的整体结构

 1.2 Pytorch使用VGG16进行猫狗二分类实战

1.2.1 数据集准备

1.2.2 构建VGG网络

1.2.3 训练和评估模型


 

1. VGG16

1.1 VGG16 介绍

深度学习已经在计算机视觉领域取得了巨大的成功,特别是在图像分类任务中。VGG16是深度学习中经典的卷积神经网络(Convolutional Neural Network,CNN)之一,由牛津大学的Karen Simonyan和Andrew Zisserman在2014年提出。VGG16网络以其深度和简洁性而闻名,是图像分类中的重要里程碑。

VGG16是Visual Geometry Group的缩写,它的名字来源于提出该网络的实验室。VGG16的设计目标是通过增加网络深度来提高图像分类的性能,并展示了深度对于图像分类任务的重要性。VGG16的主要特点是将多个小尺寸的卷积核堆叠在一起,从而形成更深的网络。

1.1.1 VGG16 网络的整体结构

VGG16网络由多个卷积层和全连接层组成。它的整体结构相对简单,所有的卷积层都采用小尺寸的卷积核(通常为3x3),步幅为1,填充为1。每个卷积层后面都会跟着一个ReLU激活函数来引入非线性。

VGG16网络主要由三个部分组成:

  1. 输入层:接受图像输入,通常为224x224大小的彩色图像(RGB)。

  2. 卷积层:VGG16包含13个卷积层,其中包括五个卷积块。

  3. 全连接层:在卷积层后面是3个全连接层,用于最终的分类。

VGG16网络结构如下图:

watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDc5MTk2NA==,size_16,color_FFFFFF,t_70#pic_center

1、一张原始图片被resize到(224,224,3)。
2、conv1两次[3,3]卷积网络,输出的特征层为64,输出为(224,224,64),再2X2最大池化,输出net为(112,112,64)。
3、conv2两次[3,3]卷积网络,输出的特征层为128,输出net为(112,112,128),再2X2最大池化,输出net为(56,56,128)。
4、conv3三次[3,3]卷积网络,输出的特征层为256,输出net为(56,56,256),再2X2最大池化,输出net为(28,28,256)。
5、conv4三次[3,3]卷积网络,输出的特征层为512,输出net为(28,28,512),再2X2最大池化,输出net为(14,14,512)。
6、conv5三次[3,3]卷积网络,输出的特征层为512,输出net为(14,14,512),再2X2最大池化,输出net为(7,7,512)。
7、利用卷积的方式模拟全连接层,效果等同,输出net为(1,1,4096)。共进行两次。
8、利用卷积的方式模拟全连接层,效果等同,输出net为(1,1,1000)。
最后输出的就是每个类的预测。

 1.2 Pytorch使用VGG16进行猫狗二分类实战

在这一部分,我们将使用PyTorch来实现VGG16网络,用于猫狗预测的二分类任务。我们将对VGG16的网络结构进行适当的修改,以适应我们的任务。

1.2.1 数据集准备

首先,我们需要准备用于猫狗二分类的数据集。数据集可以从Kaggle上下载,其中包含了大量的猫和狗的图片。在下载数据集后,我们需要将数据集划分为训练集和测试集。训练集文件夹命名为train,其中建立两个文件夹分别为cat和dog,每个文件夹里存放相应类别的图片。测试集命名为test,同理。

import torch
import torchvision
import torchvision.transforms as transforms# 定义数据转换
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载数据集
train_dataset = ImageFolder("train", transform=transform)
test_dataset = ImageFolder("test", transform=transform)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

1.2.2 构建VGG网络

import torch.nn as nnclass VGG16(nn.Module):def __init__(self):super(VGG16, self).__init__()self.features = nn.Sequential(# Block 1nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# Block 2nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# Block 3nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# Block 4nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# Block 5nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 2)  # 输出层,二分类任务)def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)  # 展开特征图x = self.classifier(x)return x# 初始化VGG16模型
vgg16 = VGG16()

在上述代码中,我们定义了一个VGG16类,其中self.features部分包含了5个卷积块,self.classifier部分包含了3个全连接层。

1.2.3 训练和评估模型

import torch.optim as optim# 定义超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 10model = VGG16()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item()}")
torch.save(model,'model/vgg16.pth')
# 测试模型
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)print(outputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Accuracy on test images: {(correct / total) * 100}%")

在训练模型时,我们使用交叉熵损失函数(CrossEntropyLoss)作为分类任务的损失函数,并采用随机梯度下降(SGD)作为优化器。同时,我们将模型移动到GPU(如果可用)来加速训练过程。

 

 


 

 

相关文章:

Pytorch使用VGG16模型进行预测猫狗二分类

目录 1. VGG16 1.1 VGG16 介绍 1.1.1 VGG16 网络的整体结构 1.2 Pytorch使用VGG16进行猫狗二分类实战 1.2.1 数据集准备 1.2.2 构建VGG网络 1.2.3 训练和评估模型 1. VGG16 1.1 VGG16 介绍 深度学习已经在计算机视觉领域取得了巨大的成功,特别是在图像分类任…...

安装nvm使用nvm管理node切换npm镜像后使用vue ui管理构建项目成功

如果安装nvm前已经单独安装过node.js的请先自行卸载原有node和环境变量里面的配置; 亲测成功,有哪些问题可以在评论区发消息或者私聊我 1、安装nvm的步骤如下 下载nvm安装包 在nvm的GitHub仓库,如下是国内镜像仓库: 点击这里跳…...

在线LaTeX公式编辑器编辑公式

在线LaTeX公式编辑器编辑公式 在编辑LaTex文档时候,需要输入公式,可以使用在线LaTeX公式编辑器编辑公式,其链接为: 在线LaTeX公式编辑器,https://www.latexlive.com/home 图1 在线LaTeX公式编辑器界面 图2 在线LaTeX公式编辑器…...

【C、C++】学习0

C、C学习路线 C语法:变量、条件、循环、字符串、数组、函数、结构体等指针、内存管理推荐书籍:《C Primer Plus》、《C和指针》、《C专家编程》 CC语言基础C的面向对象(封装、继承与多态)特性泛型模板STL等等推荐书籍(…...

python GUI nicegui初识一(登录界面创建)

最近尝试了python的nicegui库,虽然可能也有一些不足,但个人感觉对于想要开发不过对ui设计感到很麻烦的人来说是很友好的了,毕竟nicegui可以利用TailwindCSS和Quasar进行ui开发,并且也支持定制自己的css样式。 这里记录一下自己利…...

【单片机】51单片机串口的收发实验,串口程序

这段代码是使用C语言编写的用于8051单片机的串口通信程序。它实现了以下功能: 引入必要的头文件,包括reg52.h、intrins.h、string.h、stdio.h和stdlib.h。 定义了常量FSOC和BAUD,分别表示系统时钟频率和波特率。 定义了一个发送数据的函数…...

【bug】记录一次使用Swiper插件时loop属性和slidersPerView属性冲突问题

简言 最近在vue3使用swiper时,突然发现loop属性和slides-per-view属性同时存在启用时,loop生效,下一步只能生效一次的bug,上一步却是好的。非常滴奇怪。 解决过程 分析属性是否使用错误。 loop是循环模式,布尔型。 …...

云原生应用里的服务发现

服务定义: 服务定义是声明给定服务如何被消费者/客户端使用的方式。在建立服务之间的同步通信通道之前,它会与消费者共享。 同步通信中的服务定义: 微服务可以将其服务定义发布到服务注册表(或由微服务所有者手动发布)…...

【零基础学Rust | 基础系列 | 基础语法】变量,数据类型,运算符,控制流

文章目录 简介:一,变量1,变量的定义2,变量的可变性3,变量的隐藏 二、数据类型1,标量类型2,复合类型 三,运算符1,算术运算符2,比较运算符3,逻辑运算…...

运输层---概述

目录 运输层主要内容一.概述和传输层服务1.1 概述1.2 传输服务和协议1.3 传输层 vs. 网络层1.4 Internet传输层协议 二. 多路复用与多路分解(解复用)2.1 概述2.2 无连接与面向连接的多路分解(解复用)2.3面向连接的多路复用*2.4 We…...

高速公路巡检无人机,为何成为公路巡检的主流工具

随着无人机技术的不断发展,无人机越来越多地应用于各个领域。其中,在高速公路领域,高速公路巡检无人机已成为公路巡检的得力助手。高速公路巡检无人机之所以能够成为公路巡检中的主流工具,主要是因为其具备以下三大特性。 一、高速…...

仓库管理系统有哪些功能,如何对仓库进行有效管理

阅读本文,您可以了解:1、仓库管理系统有哪些功能;2、如何对仓库进行有效管理。 仓库是制造业的开端,原材料的领料开始。企业的仓库管理是涉及企业生产、企业资金流和企业的经营风险的关键环节。在众多的工业企业、制造型企业、贸…...

Java 比Automic更高效的累加器

1、 java常见的原子类 类 Atomiclnteger、AtomicIntegerArray、AtomicIntegerFieldUpdater、AtomicLongArray、 AtomicLongFieldUpdater、AtomicReference、AtomicReferenceArray 和 AtomicReference- FieldUpdater 常见的原子类使用方法 使用 AtomicReference 来创建一个原…...

antDv table组件滚动截图方法的实现

在开发中经常遇到table内容过多产生滚动的场景,正常情况下不产生滚动进行截图就很好实现,一旦产生滚动就会变得有点棘手。 下面分两种场景阐述解决的方法过程 场景一:右侧不固定列的情况 场景二:右侧固定列的情况 场景一 打开…...

JavaSE【抽象类和接口】(抽象类、接口、实现多个接口、接口的继承)

一、抽象类 在 Java 中,一个类如果被 abstract 修饰称为抽象类,抽象类中被 abstract 修饰的方法称为抽象方法,抽象方法不用 给出具体的实现体。 1.语法 // 抽象类:被 abstract 修饰的类 public abstract class Shape { …...

微信小程序如何跳转H5页面

1、登录微信公众后台,进入【开发->开发管理->业务域名】,点击修改。 2、首先请下载校验文件,并将文件放置在域名根目录下。 我是把文件放在nginx主机的data目录下,然后通过增加nginx.config配置,重启nginx后可…...

C++(20):bit_cast

C++20之前如果想对不同的指针之间做类型转换需要通过reinterpret_cast,对于整数与指针之前的转换也需要通过reinterpret_cast: C++:reinterpret_cast_c++ reparant_cast_风静如云的博客-CSDN博客 但是reinterpret_cast的缺点是不同的编译环境下,无法包装转型的安全一致。 …...

STM32 低功耗-停止模式

STM32 停止模式 文章目录 STM32 停止模式第1章 低功耗模式简介第2章 停止模式简介2.1 进入停止模式2.1 退出停止模式 第3章 停止模式程序部分总结 第1章 低功耗模式简介 在 STM32 的正常工作中,具有四种工作模式:运行、睡眠、停止以及待机模式。 在系统…...

Hutool中 常用的工具类和方法

文章目录 日期时间工具类 DateUtil日期时间对象-DateTime类型转换工具类 Convert字符串工具类 StrUtil数字处理工具类 NumberUtilJavaBean的工具类 BeanUtil集合操作的工具类 CollUtilMap操作工具类 MapUtil数组工具-ArrayUtil唯一ID工具-IdUtilIO工具类-IoUtil加密解密工具类 …...

K8s(健康检查+滚动更新+优雅停机+弹性伸缩+Prometheus监控+配置分离)

前言 快速配置请直接跳转至汇总配置 K8s SpringBoot实现零宕机发布:健康检查滚动更新优雅停机弹性伸缩Prometheus监控配置分离(镜像复用) 配置 健康检查 健康检查类型:就绪探针(readiness) 存活探针&am…...

Vue记事本应用实现教程

文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展:显示创建时间8. 功能扩展:记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...

多场景 OkHttpClient 管理器 - Android 网络通信解决方案

下面是一个完整的 Android 实现&#xff0c;展示如何创建和管理多个 OkHttpClient 实例&#xff0c;分别用于长连接、普通 HTTP 请求和文件下载场景。 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas…...

如何在看板中体现优先级变化

在看板中有效体现优先级变化的关键措施包括&#xff1a;采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中&#xff0c;设置任务排序规则尤其重要&#xff0c;因为它让看板视觉上直观地体…...

DAY 47

三、通道注意力 3.1 通道注意力的定义 # 新增&#xff1a;通道注意力模块&#xff08;SE模块&#xff09; class ChannelAttention(nn.Module):"""通道注意力模块(Squeeze-and-Excitation)"""def __init__(self, in_channels, reduction_rat…...

1688商品列表API与其他数据源的对接思路

将1688商品列表API与其他数据源对接时&#xff0c;需结合业务场景设计数据流转链路&#xff0c;重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点&#xff1a; 一、核心对接场景与目标 商品数据同步 场景&#xff1a;将1688商品信息…...

pycharm 设置环境出错

pycharm 设置环境出错 pycharm 新建项目&#xff0c;设置虚拟环境&#xff0c;出错 pycharm 出错 Cannot open Local Failed to start [powershell.exe, -NoExit, -ExecutionPolicy, Bypass, -File, C:\Program Files\JetBrains\PyCharm 2024.1.3\plugins\terminal\shell-int…...

Python训练营-Day26-函数专题1:函数定义与参数

题目1&#xff1a;计算圆的面积 任务&#xff1a; 编写一个名为 calculate_circle_area 的函数&#xff0c;该函数接收圆的半径 radius 作为参数&#xff0c;并返回圆的面积。圆的面积 π * radius (可以使用 math.pi 作为 π 的值)要求&#xff1a;函数接收一个位置参数 radi…...

机器学习的数学基础:线性模型

线性模型 线性模型的基本形式为&#xff1a; f ( x ) ω T x b f\left(\boldsymbol{x}\right)\boldsymbol{\omega}^\text{T}\boldsymbol{x}b f(x)ωTxb 回归问题 利用最小二乘法&#xff0c;得到 ω \boldsymbol{\omega} ω和 b b b的参数估计$ \boldsymbol{\hat{\omega}}…...

多元隐函数 偏导公式

我们来推导隐函数 z z ( x , y ) z z(x, y) zz(x,y) 的偏导公式&#xff0c;给定一个隐函数关系&#xff1a; F ( x , y , z ( x , y ) ) 0 F(x, y, z(x, y)) 0 F(x,y,z(x,y))0 &#x1f9e0; 目标&#xff1a; 求 ∂ z ∂ x \frac{\partial z}{\partial x} ∂x∂z​、 …...

Git 命令全流程总结

以下是从初始化到版本控制、查看记录、撤回操作的 Git 命令全流程总结&#xff0c;按操作场景分类整理&#xff1a; 一、初始化与基础操作 操作命令初始化仓库git init添加所有文件到暂存区git add .提交到本地仓库git commit -m "提交描述"首次提交需配置身份git c…...