Pytorch迁移学习使用Resnet50进行模型训练预测猫狗二分类
目录
1.ResNet残差网络
1.1 ResNet定义
1.2 ResNet 几种网络配置
1.3 ResNet50网络结构
1.3.1 前几层卷积和池化
1.3.2 残差块:构建深度残差网络
1.3.3 ResNet主体:堆叠多个残差块
1.4 迁移学习猫狗二分类实战
1.4.1 迁移学习
1.4.2 模型训练
1.4.3 模型预测
1.ResNet残差网络
1.1 ResNet定义
深度学习在图像分类、目标检测、语音识别等领域取得了重大突破,但是随着网络层数的增加,梯度消失和梯度爆炸问题逐渐凸显。随着层数的增加,梯度信息在反向传播过程中逐渐变小,导致网络难以收敛。同时,梯度爆炸问题也会导致网络的参数更新过大,无法正常收敛。
为了解决这些问题,ResNet提出了一个创新的思路:引入残差块(Residual Block)。残差块的设计允许网络学习残差映射,从而减轻了梯度消失问题,使得网络更容易训练。
下图是一个基本残差块。它的操作是把某层输入跳跃连接到下一层乃至更深层的激活层之前,同本层输出一起经过激活函数输出。
1.2 ResNet 几种网络配置
如下图:
1.3 ResNet50网络结构
ResNet-50是一个具有50个卷积层的深度残差网络。它的网络结构非常复杂,但我们可以将其分为以下几个模块:
1.3.1 前几层卷积和池化
import torch
import torch.nn as nnclass ResNet50(nn.Module):def __init__(self, num_classes=1000):super(ResNet50, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
1.3.2 残差块:构建深度残差网络
class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * 4, kernel_size=1, stride=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * 4)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out
1.3.3 ResNet主体:堆叠多个残差块
在ResNet-50中,我们堆叠了多个残差块来构建整个网络。每个残差块会将输入的特征图进行处理,并输出更加丰富的特征图。堆叠多个残差块允许网络在深度方向上进行信息的层层提取,从而获得更高级的语义信息。代码如下:
class ResNet50(nn.Module):def __init__(self, num_classes=1000):# ... 前几层代码 ...# 4个残差块的block1self.layer1 = self._make_layer(ResidualBlock, 64, 3, stride=1)# 4个残差块的block2self.layer2 = self._make_layer(ResidualBlock, 128, 4, stride=2)# 4个残差块的block3self.layer3 = self._make_layer(ResidualBlock, 256, 6, stride=2)# 4个残差块的block4self.layer4 = self._make_layer(ResidualBlock, 512, 3, stride=2)
利用make_layer函数实现对基本残差块Bottleneck的堆叠。代码如下:
def _make_layer(self, block, channel, block_num, stride=1):"""block: 堆叠的基本模块channel: 每个stage中堆叠模块的第一个卷积的卷积核个数,对resnet50分别是:64,128,256,512block_num: 当期stage堆叠block个数stride: 默认卷积步长"""downsample = None # 用于控制shorcut路的if stride != 1 or self.in_channel != channel*block.expansion: # 对resnet50:conv2中特征图尺寸H,W不需要下采样/2,但是通道数x4,因此shortcut通道数也需要x4。对其余conv3,4,5,既要特征图尺寸H,W/2,又要shortcut维度x4downsample = nn.Sequential(nn.Conv2d(in_channels=self.in_channel, out_channels=channel*block.expansion, kernel_size=1, stride=stride, bias=False), # out_channels决定输出通道数x4,stride决定特征图尺寸H,W/2nn.BatchNorm2d(num_features=channel*block.expansion))layers = [] # 每一个convi_x的结构保存在一个layers列表中,i={2,3,4,5}layers.append(block(in_channel=self.in_channel, out_channel=channel, downsample=downsample, stride=stride)) # 定义convi_x中的第一个残差块,只有第一个需要设置downsample和strideself.in_channel = channel*block.expansion # 在下一次调用_make_layer函数的时候,self.in_channel已经x4for _ in range(1, block_num): # 通过循环堆叠其余残差块(堆叠了剩余的block_num-1个)layers.append(block(in_channel=self.in_channel, out_channel=channel))return nn.Sequential(*layers) # '*'的作用是将list转换为非关键字参数传入
1.4 迁移学习猫狗二分类实战
1.4.1 迁移学习
迁移学习(Transfer Learning)是一种机器学习和深度学习技术,它允许我们将一个任务学到的知识或特征迁移到另一个相关的任务中,从而加速模型的训练和提高性能。在迁移学习中,我们通常利用已经在大规模数据集上预训练好的模型(称为源任务模型),将其权重用于新的任务(称为目标任务),而不是从头开始训练一个全新的模型。
迁移学习的核心思想是:在解决一个新任务之前,我们可以先从已经学习过的任务中获取一些通用的特征或知识,并将这些特征或知识迁移到新任务中。这样做的好处在于,源任务模型已经在大规模数据集上进行了充分训练,学到了很多通用的特征,例如边缘检测、纹理等,这些特征对于许多任务都是有用的。
1.4.2 模型训练
首先,我们需要准备用于猫狗二分类的数据集。数据集可以从Kaggle上下载,其中包含了大量的猫和狗的图片。
在下载数据集后,我们需要将数据集划分为训练集和测试集。训练集文件夹命名为train,其中建立两个文件夹分别为cat和dog,每个文件夹里存放相应类别的图片。测试集命名为test,同理。然后我们使用ResNet50网络模型,在我们的计算机上使用GPU进行训练并保存我们的模型,训练完成后在测试集上验证模型预测的正确率。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50# 设置随机种子
torch.manual_seed(42)# 定义超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 10# 定义数据转换
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载数据集
train_dataset = ImageFolder("train", transform=transform)
test_dataset = ImageFolder("test", transform=transform)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)# 加载预训练的ResNet-50模型
model = resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2) # 替换最后一层全连接层,以适应二分类问题device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item()}")
torch.save(model,'model/c.pth')
# 测试模型
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)print(outputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()breakprint(f"Accuracy on test images: {(correct / total) * 100}%")
1.4.3 模型预测
首先加载我们保存的模型,这里我们进行单张图片的预测,并把预测结果打印日志。
import cv2 as cv
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torchvision.transforms as transforms
import torch
from PIL import Image
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=torch.load('model/c.pth')
print(model)
model.to(device)test_image_path = 'test/dogs/dog.4001.jpg' # Replace with your test image path
image = Image.open(test_image_path)
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
input_tensor = transform(image).unsqueeze(0).to(device) # Add a batch dimension and move to GPU# Set the model to evaluation mode
model.eval()with torch.no_grad():outputs = model(input_tensor)_, predicted = torch.max(outputs, 1)predicted_label = predicted.item()label=['猫','狗']
print(label[predicted_label])
plt.axis('off')
plt.imshow(image)
plt.show()
运行截图
至此这篇文章到此结束。
相关文章:

Pytorch迁移学习使用Resnet50进行模型训练预测猫狗二分类
目录 1.ResNet残差网络 1.1 ResNet定义 1.2 ResNet 几种网络配置 1.3 ResNet50网络结构 1.3.1 前几层卷积和池化 1.3.2 残差块:构建深度残差网络 1.3.3 ResNet主体:堆叠多个残差块 1.4 迁移学习猫狗二分类实战 1.4.1 迁移学习 1.4.2 模型训练 1.…...
HTML与XHTML的不同和各自特点
HTML和XHTML都是用于创建Web页面的标记语言。HTML是一种被广泛使用的标记语言,而XHTML是HTML的严格规范化版本。在本文中,我们将探讨HTML与XHTML之间的不同之处,以及它们各自的特点。 HTML与XHTML的不同之处 HTML和XHTML之间最大的不同在于它…...

微服务如何治理
微服务远程调用可能有如下问题: 注册中心宕机; 服务提供者B有节点宕机; 服务消费者A和注册中心之间的网络不通; 服务提供者B和注册中心之间的网络不通; 服务消费者A和服务提供者B之间的网络不通; 服务提供者…...

一本通1919:【02NOIP普及组】选数
这道题感觉很好玩。 正文: 先放题目: 信息学奥赛一本通(C版)在线评测系统 (ssoier.cn)http://ybt.ssoier.cn:8088/problem_show.php?pid1919 描述 已知 n 个整数 x1,x2,…,xn,以及一个整数 k(k&#…...
Kubernetes 集群管理和编排
文章目录 总纲第一章:引入 Kubernetes什么是容器编排和管理?容器编排和管理的重要性Kubernetes作为容器编排和管理解决方案 Kubernetes 的背景和发展起源和发展历程Kubernetes 项目的目标和动机 Kubernetes 的作用和优势作用优势 Kubernetes 的特点和核心…...
DDS协议--[第六章][Discovery]
DDS协议–Discovery 文章目录 DDS协议--Discovery侦听通告DDS提供发现协议参与者发现阶段(PDP)端点发现阶段(EDP)Fast DDS提供如下四种发现机制:简单发现机制简单发现机制步骤:侦听 侦听定位器用于接收DomainParticipant上的传入流量,是DDS发现机制和数据传输机制的关键…...
如何设置iptables,让网络流量转发给内部容器mysql
1.创建一个mysql ,无法外部访问 docker run -d --name mysql_container -e MYSQL_ROOT_PASSWORDliuyunshengsir -v /path/to/mysql_data:/var/lib/mysql mysql2.设置规则外部直接可访问 要使用 iptables 将网络流量转发给内部容器中的 MySQL 服务,你可…...

数字IC实践项目(7)—CNN加速器的设计和实现(付费项目)
数字IC实践项目(7)—基于Verilog的CNN加速器(付费项目) 写在前面的话项目整体框图神经网络框图完整电路框图 项目简介和学习目的软件环境要求 资源占用&板载功耗总结 写在前面的话 项目介绍: 卷积神经网络硬件加速…...

基于深度学习的高精度80类动物目标检测系统(PyTorch+Pyside6+YOLOv5模型)
摘要:基于深度学习的高精度80类动物目标检测识别系统可用于日常生活中或野外来检测与定位80类动物目标,利用深度学习算法可实现图片、视频、摄像头等方式的80类动物目标检测识别,另外支持结果可视化与图片或视频检测结果的导出。本系统采用YO…...

海康摄像头开发笔记(一):连接防爆摄像头、配置摄像头网段、设置rtsp码流、播放rtsp流、获取rtsp流、调优rtsp流播放延迟以及录像存储
文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/131679108 红胖子(红模仿)的博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV、OpenGL、ffmpeg、OSG、单片机、软硬结…...
【NCNN】NCNN中Mat与CV中Mat的使用区别及相互转换方法
目录 相同点与不同点cv::Mat转ncnn::Matcv::Mat CV_8UC3 -> ncnn::Mat 3 channel swap RGB/BGRcv::Mat CV_8UC3 -> ncnn::Mat 1 channel do RGB2GRAY/BGR2GRAYcv::Mat CV_8UC1 -> ncnn::Mat 1 channel ncnn::Mat转cv::Mancnn::Mat 3 channel -> cv::Mat CV_8UC3 …...
Android 13 设置自动进入wifi adb模式
Android 13 设置自动进入wifi adb模式 文章目录 Android 13 设置自动进入wifi adb模式一、前言:二、解决Android 13 wifi adb每次重启自动重置问题方法1、分析系统中每次重置wifi adb属性的代码2、在开机广播里面进行设置wifi adb 相关属性(1)…...
(笔记)插入排序
插入排序 插入排序是一种简单且常见的排序算法,它通过重复将一个元素插入到已经排好序的一组元素中,来达到排序的目的。在插入排序算法中,将待排序序列分为已排序和未排序两个部分。初始时,已排序部分只包含一个记录,…...

结构型模式 - 组合模式
概述 对于这个图片肯定会非常熟悉,上图我们可以看做是一个文件系统,对于这样的结构我们称之为树形结构。在树形结构中可以通过调用某个方法来遍历整个树,当我们找到某个叶子节点后,就可以对叶子节点进行相关的操作。可以将这颗树理…...

EDM营销过时了?不,这才是跨境电商成功的最佳工具
根据最近的一项研究,电子邮件仍然是最具说服力的营销工具和沟通形式之一。虽然即时通讯等其他渠道正在扎根,但电子邮件仍然是影响最深远的商业交流形式。到2023年,每天发送和接收的电子邮件总数可能会超过333亿封。所以,如果您希望…...
【大数据之Hive】二十五、HQL语法优化之小文件合并
1 优化说明 小文件优化可以从两个方面解决,在Map端输入的小文件合并,在Reduce端输出的小文件合并。 1.1 Map端输入文件合并 合并Map端输入的小文件是指将多个小文件分到同一个切片中,由一个Map Task处理,防止单个小文件启动一个M…...
spring 连接oracle数据库报错{dataSource-1} init error解决,电脑用户名问题
错误描述: 连接oracle数据就报错,同样的代码其他电脑不会报错。 报错如下: {dataSource-1} init error java.sql.SQLRecoverableException: IO 错误: Undefined Error com.alibaba.druid.pool.DruidDataSource-1049[main]ERROR: {dataSourc…...
行业视野::人工智能与机器人
控制和机器人领域非常重要的quote:莫拉维克悖论(Moravecs paradox) It is comparatively easy to make computers exhibit adult level performance on intelligence tests or playing checkers,and difficult or impossible to give them th…...
【Python入门系列】第十七篇:Python大数据处理和分析
【Python入门系列】第十七篇:Python大数据处理和分析 文章目录 前言一、数据处理和分析步骤二、Python大数据处理和分析库三、Python大数据处理和分析应用1、数据清洗和转换2、数据分析和统计3、数据可视化4、机器学习模型训练和预测5、大规模数据处理和分布式计算6…...

spring.profiles的使用详解
本文来说下spring.profiles.active和spring.profiles.include的使用与区别 文章目录 业务场景spring.profiles.active属性启动时指定 spring.profiles.include属性配置方法配置位置配置区别 用示例来使用和区分测试一测试二测试三 编写程序查看激活的yml文件本文小结 业务场景 …...

剑指offer20_链表中环的入口节点
链表中环的入口节点 给定一个链表,若其中包含环,则输出环的入口节点。 若其中不包含环,则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...
浅谈不同二分算法的查找情况
二分算法原理比较简单,但是实际的算法模板却有很多,这一切都源于二分查找问题中的复杂情况和二分算法的边界处理,以下是博主对一些二分算法查找的情况分析。 需要说明的是,以下二分算法都是基于有序序列为升序有序的情况…...

如何在网页里填写 PDF 表格?
有时候,你可能希望用户能在你的网站上填写 PDF 表单。然而,这件事并不简单,因为 PDF 并不是一种原生的网页格式。虽然浏览器可以显示 PDF 文件,但原生并不支持编辑或填写它们。更糟的是,如果你想收集表单数据ÿ…...

AI病理诊断七剑下天山,医疗未来触手可及
一、病理诊断困局:刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断",医生需通过显微镜观察组织切片,在细胞迷宫中捕捉癌变信号。某省病理质控报告显示,基层医院误诊率达12%-15%,专家会诊…...
C++.OpenGL (14/64)多光源(Multiple Lights)
多光源(Multiple Lights) 多光源渲染技术概览 #mermaid-svg-3L5e5gGn76TNh7Lq {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-3L5e5gGn76TNh7Lq .error-icon{fill:#552222;}#mermaid-svg-3L5e5gGn76TNh7Lq .erro…...
C#中的CLR属性、依赖属性与附加属性
CLR属性的主要特征 封装性: 隐藏字段的实现细节 提供对字段的受控访问 访问控制: 可单独设置get/set访问器的可见性 可创建只读或只写属性 计算属性: 可以在getter中执行计算逻辑 不需要直接对应一个字段 验证逻辑: 可以…...

从“安全密码”到测试体系:Gitee Test 赋能关键领域软件质量保障
关键领域软件测试的"安全密码":Gitee Test如何破解行业痛点 在数字化浪潮席卷全球的今天,软件系统已成为国家关键领域的"神经中枢"。从国防军工到能源电力,从金融交易到交通管控,这些关乎国计民生的关键领域…...
HybridVLA——让单一LLM同时具备扩散和自回归动作预测能力:训练时既扩散也回归,但推理时则扩散
前言 如上一篇文章《dexcap升级版之DexWild》中的前言部分所说,在叠衣服的过程中,我会带着团队对比各种模型、方法、策略,毕竟针对各个场景始终寻找更优的解决方案,是我个人和我司「七月在线」的职责之一 且个人认为,…...

ZYNQ学习记录FPGA(一)ZYNQ简介
一、知识准备 1.一些术语,缩写和概念: 1)ZYNQ全称:ZYNQ7000 All Pgrammable SoC 2)SoC:system on chips(片上系统),对比集成电路的SoB(system on board) 3)ARM:处理器…...
[USACO23FEB] Bakery S
题目描述 Bessie 开了一家面包店! 在她的面包店里,Bessie 有一个烤箱,可以在 t C t_C tC 的时间内生产一块饼干或在 t M t_M tM 单位时间内生产一块松糕。 ( 1 ≤ t C , t M ≤ 10 9 ) (1 \le t_C,t_M \le 10^9) (1≤tC,tM≤109)。由于空间…...