当前位置: 首页 > news >正文

PyTorch--残差网络(ResNet)在CIFAR-10数据集进行图像分类

完整代码

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
num_epochs = 80
batch_size = 100
learning_rate = 0.001# Image preprocessing modules
transform = transforms.Compose([transforms.Pad(4),transforms.RandomHorizontalFlip(),transforms.RandomCrop(32),transforms.ToTensor()])# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='../../data/',train=True, transform=transform,download=True)test_dataset = torchvision.datasets.CIFAR10(root='../../data/',train=False, transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)# 3x3 convolution
def conv3x3(in_channels, out_channels, stride=1):return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)# Residual block
class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(ResidualBlock, self).__init__()self.conv1 = conv3x3(in_channels, out_channels, stride)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(out_channels, out_channels)self.bn2 = nn.BatchNorm2d(out_channels)self.downsample = downsampledef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample:residual = self.downsample(x)out += residualout = self.relu(out)return out# ResNet
class ResNet(nn.Module):def __init__(self, block, layers, num_classes=10):super(ResNet, self).__init__()self.in_channels = 16self.conv = conv3x3(3, 16)self.bn = nn.BatchNorm2d(16)self.relu = nn.ReLU(inplace=True)self.layer1 = self.make_layer(block, 16, layers[0])self.layer2 = self.make_layer(block, 32, layers[1], 2)self.layer3 = self.make_layer(block, 64, layers[2], 2)self.avg_pool = nn.AvgPool2d(8)self.fc = nn.Linear(64, num_classes)def make_layer(self, block, out_channels, blocks, stride=1):downsample = Noneif (stride != 1) or (self.in_channels != out_channels):downsample = nn.Sequential(conv3x3(self.in_channels, out_channels, stride=stride),nn.BatchNorm2d(out_channels))layers = []layers.append(block(self.in_channels, out_channels, stride, downsample))self.in_channels = out_channelsfor i in range(1, blocks):layers.append(block(out_channels, out_channels))return nn.Sequential(*layers)def forward(self, x):out = self.conv(x)out = self.bn(out)out = self.relu(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.avg_pool(out)out = out.view(out.size(0), -1)out = self.fc(out)return outmodel = ResNet(ResidualBlock, [2, 2, 2]).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# For updating learning rate
def update_lr(optimizer, lr):    for param_group in optimizer.param_groups:param_group['lr'] = lr# Train the model
total_step = len(train_loader)
curr_lr = learning_rate
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print ("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch+1, num_epochs, i+1, total_step, loss.item()))# Decay learning rateif (epoch+1) % 20 == 0:curr_lr /= 3update_lr(optimizer, curr_lr)# Test the model
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'resnet.ckpt')

这段代码是一个PyTorch实现的残差网络(ResNet),用于在CIFAR-10数据集上进行图像分类任务。下面是代码的详细解析:

代码解析

导入必要的库

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

导入PyTorch及其神经网络模块、torchvision库用于处理图像数据。

设备配置

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

设置运行设备,优先使用GPU,如果没有GPU,则使用CPU。

超参数设置

num_epochs = 80
batch_size = 100
learning_rate = 0.001

设置训练轮数、批次大小和学习率。

数据预处理

transform = transforms.Compose([transforms.Pad(4),transforms.RandomHorizontalFlip(),transforms.RandomCrop(32),transforms.ToTensor()])

定义数据预处理步骤,包括填充、随机水平翻转、随机裁剪和转换为张量。

加载CIFAR-10数据集

train_dataset = torchvision.datasets.CIFAR10(root='../../data/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='../../data/', train=False, transform=transforms.ToTensor())

加载CIFAR-10训练集和测试集。

创建数据加载器

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

创建用于加载数据的DataLoader。

定义3x3卷积函数

def conv3x3(in_channels, out_channels, stride=1):return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)

定义一个3x3的卷积层。

定义残差块

class ResidualBlock(nn.Module):# ...

定义残差网络中的残差块,包含两个卷积层和批量归一化层。

定义ResNet模型

class ResNet(nn.Module):# ...

定义ResNet模型,使用残差块构建多个层。

实例化模型并移动到设备

model = ResNet(ResidualBlock, [2, 2, 2]).to(device)

创建ResNet模型实例并将其移动到配置的设备上。

定义损失函数和优化器

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

定义交叉熵损失函数和Adam优化器。

学习率衰减函数

def update_lr(optimizer, lr):    for param_group in optimizer.param_groups:param_group['lr'] = lr

定义一个函数用于更新优化器的学习率。

训练模型

for epoch in range(num_epochs):# ...# 每20个epoch衰减学习率if (epoch+1) % 20 == 0:curr_lr /= 3update_lr(optimizer, curr_lr)

执行训练循环,包括前向传播、损失计算、反向传播和参数更新,并每20个epoch衰减学习率。

测试模型

model.eval()
with torch.no_grad():# ...

在测试阶段,设置模型为评估模式,并计算准确率。

保存模型

torch.save(model.state_dict(), 'resnet.ckpt')

保存模型的状态字典。

这段代码实现了一个标准的ResNet架构,用于CIFAR-10数据集的分类任务。代码中包含了数据预处理、模型定义、训练过程、测试评估和模型保存等关键步骤。

常见函数及其用法

以下是代码中使用的常见函数及其解析:

  1. torch.device

    • 格式:torch.device(device_str)
    • 参数:device_str —— 设备类型字符串(如’cuda’或’cpu’)。
    • 意义:确定模型和张量运行的设备。
    • 用法示例:device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  2. torchvision.datasets.CIFAR10

    • 格式:torchvision.datasets.CIFAR10(root, train, transform, download)
    • 参数:指定数据集的路径、是否为训练集、预处理变换、是否下载数据集。
    • 意义:加载CIFAR-10数据集。
    • 用法示例:train_dataset = torchvision.datasets.CIFAR10(root='../../data/', train=True, transform=transform, download=True)
  3. torch.utils.data.DataLoader

    • 格式:torch.utils.data.DataLoader(dataset, batch_size, shuffle)
    • 参数:数据集对象、批次大小、是否打乱数据。
    • 意义:创建数据加载器,用于批量加载数据。
    • 用法示例:train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
  4. nn.Conv2d

    • 格式:nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
    • 参数:输入通道数、输出通道数、卷积核大小、步长、填充。
    • 意义:创建二维卷积层。
    • 用法示例:return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
  5. nn.BatchNorm2d

    • 格式:nn.BatchNorm2d(num_features)
    • 参数:特征数量。
    • 意义:创建二维批量归一化层。
    • 用法示例:self.bn1 = nn.BatchNorm2d(out_channels)
  6. nn.ReLU

    • 格式:nn.ReLU(inplace=True/False)
    • 参数:是否使用内存原地(inplace)优化。
    • 意义:创建ReLU激活层。
    • 用法示例:self.relu = nn.ReLU(inplace=True)
  7. nn.Sequential

    • 格式:nn.Sequential(*modules)
    • 参数:一个模块序列。
    • 意义:按顺序应用多个模块。
    • 用法示例:downsample = nn.Sequential(conv3x3(self.in_channels, out_channels, stride=stride), nn.BatchNorm2d(out_channels))
  8. nn.CrossEntropyLoss

    • 格式:nn.CrossEntropyLoss()
    • 意义:创建交叉熵损失层,用于多分类问题。
    • 用法示例:criterion = nn.CrossEntropyLoss()
  9. torch.optim.Adam

    • 格式:torch.optim.Adam(params, lr)
    • 参数:模型参数、学习率。
    • 意义:创建Adam优化器。
    • 用法示例:optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  10. .to(device)

    • 格式:.to(device)
    • 参数:设备对象。
    • 意义:将模型或张量移动到指定设备。
    • 用法示例:images = images.to(device)
  11. view

    • 格式:view(size)
    • 参数:新的大小。
    • 意义:重塑张量。
    • 用法示例:out = out.view(out.size(0), -1)
  12. max

    • 格式:max(dim, keepdim)
    • 参数:计算最大值的维度、是否保持维度。
    • 意义:计算并返回张量在指定维度上的最大值和索引。
    • 用法示例:_, predicted = torch.max(outputs.data, 1)
  13. no_grad

    • 格式:torch.no_grad()
    • 意义:上下文管理器,用于禁用梯度计算。
    • 用法示例:with torch.no_grad():
  14. sum

    • 格式:sum(dim, keepdim)
    • 参数:求和的维度、是否保持维度。
    • 意义:计算张量在指定维度的和。
    • 用法示例:correct += (predicted == labels).sum().item()
  15. torch.save

    • 格式:torch.save(obj, f)
    • 参数:要保存的对象、文件路径。
    • 意义:保存对象到文件。
    • 用法示例:torch.save(model.state_dict(), 'resnet.ckpt')

这些函数和类是构建、训练和测试PyTorch模型的基础,涵盖了设备配置、数据加载、模型定义、训练过程、测试评估和模型保存等关键步骤。

运行过程

在整体进行一些可视化改进之后,可以看到效果图如下图所示:
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

相关文章:

PyTorch--残差网络(ResNet)在CIFAR-10数据集进行图像分类

完整代码 import torch import torch.nn as nn import torchvision import torchvision.transforms as transforms# Device configuration device torch.device(cuda if torch.cuda.is_available() else cpu)# Hyper-parameters num_epochs 80 batch_size 100 learning_rate…...

ETAS工具链自动化实战指南<一>

----自动化不仅是一种技术,更是一种思维方式,它将帮助我们在快节奏的工作环境中保持领先! 目录 往期推荐 场景一:SWC 之间 port自动连接 命令示例 参数说明 场景二:SWC与ECU 自动映射 命令示例 参数说明 场景三&…...

疫情期间我面试了13家企业软件测试岗位,一些面试题整理

项目的测试流程 拿到需求文档后,写测试用例 审核测试用例 等待开发包 部署测试环境 冒烟测试(网页架构图) 页面初始化测试(查看数据库中的数据内容和页面展示的内容是否一致,并且是否按照某些顺序排列&#xff09…...

PINCE——Linux 原生游戏内存修改器,一款替代 Cheat Engine 的强大游戏修改器,Linux 游戏玩家必备神器!

PINCE——Linux 原生游戏内存修改器,一款替代 Cheat Engine 的强大游戏修改器,Linux 游戏玩家必备神器! PINCE 是 GNU Project Debugger(GDB) 的前端/反向工程工具,常用作程序调试器,主要用于游戏领域,修改…...

为IntelliJ IDEA安装插件

安装插件 插件是开发工具的扩展程序,通常由第三方提供,当安装了插件后,原开发工作的菜单、按钮等开发环境可能会发生变化,例如出现了新的菜单项,或出现了新的按钮,甚至一些全新的编码方式,通常…...

ES6 Promise

ES6 Promise 对象 一、概述 是异步编程的一种解决方案。 从语法上说,Promise 是一个对象,从它可以获取异步操作的消息。 Promise 状态 状态的特点 Promise 异步操作有三种状态:pending(进行中)、fulfilled(…...

html+css 实现hover 凹陷按钮

前言:哈喽,大家好,今天给大家分享html+css 绚丽效果!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏+关注哦 💕 目录 📚一、效果📚二、原理解析💡1.这是一个,hover时凹陷的效果。每个按钮是一个button…...

什么是负载均衡?负载均衡器如何运作?

往期文章 负载均衡器:LVS、Nginx、HAproxy如何选择? 目录 往期文章什么是负载均衡?为什么需要负载均衡?负载均衡工作原理?静态负载均衡算法动态负载均衡算法 参考 什么是负载均衡? 负载均衡是一种网络技术…...

(Arxiv-2023)潜在一致性模型:通过少步推理合成高分辨率图像

潜在一致性模型:通过少步推理合成高分辨率图像 Paper Title: Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference Paper是清华发表在Arxiv 2023的工作 Paper地址 Code地址 ABSTRACT 潜在扩散模型 (LDM) 在合成高分辨率图像方…...

Unity与UE,哪种游戏引擎适合你?

PlayStation vs Xbox,Mario vs Sonic,Unreal vs Unity?无论是游戏主机、角色还是游戏引擎,人们总是热衷于捍卫他们在游戏行业中的偏爱。 专注于游戏引擎,Unity和Unreal Engine(简称UE4)是目前市…...

这五本大模型书籍,把大模型讲的非常详细,收藏我这一篇就够了

当然可以。在当前的大模型时代,随着自然语言处理(NLP)技术的迅速发展,出现了许多优秀的书籍来帮助读者理解这些复杂的技术。以下是几本值得推荐的大模型书籍,它们涵盖了从基础理论到高级实践的内容,可以帮助…...

伊朗通过 ChatGPT 试图影响美国大选, OpenAI 封禁多个账户|TodayAI

OpenAI 近日宣布,他们已经封禁了一系列与伊朗影响行动有关的 ChatGPT 账户,这些账户涉嫌利用该 AI 工具生成并传播与美国总统选举、以色列 – 哈马斯战争以及奥运会等相关的内容。 OpenAI 表示,这些账户与一个名为 “Storm-2035” 的秘密伊朗…...

windows系统如何走后面之windows系统隐藏账户

系统隐藏账户是一种最为简单有效的权限维持方式,其做法就是让攻击者创建一个新的具有管理员权限的隐藏账户,因为是隐藏账户,所以防守方是无法通过控制面板或命令行看到这个账户的。 自然我们需要一些前提条件,比如说有一个网站&am…...

Elasticsearch(ES)(版本7.x)数据更新后刷新策略RefreshPolicy

Elasticsearch(ES)(版本7.x)数据更新后刷新策略RefreshPolicy 介绍 ES数据写入后,默认1s后才会被搜索到(refresh_interval为1); 这样可能是考虑到性能问题,毕竟实时IO 消耗较多资源 造成的问题 例如一个索引现在有…...

【运维】从一个git库迁移到另一个库

工作目录: /home/java/hosts 10.60.100.194 脚本 hosts / hostsShell GitLab (gbcom.com.cn) 核心代码...

and design vue表格列宽度拖拽,vue-draggable-resizable插件使用

and design vue2版的table表格不能拖拽列的宽度,通过vue-draggable-resizable插件实现 我用的是and design 1.7.8的版本,先下插件 yarn add vue-draggable-resizable2.1.0我这版本的and design用最新3.0.0以上的插件会有问题,实现不了效果&a…...

使用hexo搭建个人博客

很早之前使用hexo和github建了个人博客。搭建的流程一直没有梳理,中间换过几次机器,每次都得重新配置一遍,需要重新学些。最近电脑坏了,原始的数据没有导出来,先把以前文章写个文件占个位置,后面慢慢补吧&a…...

java geotool构建地理点线面

哈喽,各位小伙伴们,你们好呀,我是喵手。运营社区:C站/掘金/腾讯云/阿里云/华为云/51CTO;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点,并以文字的形式跟大家一起交流,互…...

C# 中 Grpc服务端调用客户端方法

在 gRPC 中,服务端通常不直接调用客户端的方法,因为 gRPC 的设计模型是服务端提供服务,客户端调用服务。通常情况下,服务端和客户端之间是解耦的,服务端只提供服务端点,客户端通过这些端点发起请求。 不过…...

Arthas相关命令

官方网站:命令列表 | arthas 也可以用idea的插件arthas-idea的插件根据你想定位的代码生成命令 jvm 相关 dashboard - 当前系统的实时数据面板getstatic - 查看类的静态属性heapdump - dump java heap, 类似 jmap 命令的 heap dump 功能jvm - 查看当前 JVM 的信息l…...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架,它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用,和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

React hook之useRef

React useRef 详解 useRef 是 React 提供的一个 Hook,用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途,下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...

基于ASP.NET+ SQL Server实现(Web)医院信息管理系统

医院信息管理系统 1. 课程设计内容 在 visual studio 2017 平台上,开发一个“医院信息管理系统”Web 程序。 2. 课程设计目的 综合运用 c#.net 知识,在 vs 2017 平台上,进行 ASP.NET 应用程序和简易网站的开发;初步熟悉开发一…...

五年级数学知识边界总结思考-下册

目录 一、背景二、过程1.观察物体小学五年级下册“观察物体”知识点详解:由来、作用与意义**一、知识点核心内容****二、知识点的由来:从生活实践到数学抽象****三、知识的作用:解决实际问题的工具****四、学习的意义:培养核心素养…...

LLM基础1_语言模型如何处理文本

基于GitHub项目:https://github.com/datawhalechina/llms-from-scratch-cn 工具介绍 tiktoken:OpenAI开发的专业"分词器" torch:Facebook开发的强力计算引擎,相当于超级计算器 理解词嵌入:给词语画"…...

DingDing机器人群消息推送

文章目录 1 新建机器人2 API文档说明3 代码编写 1 新建机器人 点击群设置 下滑到群管理的机器人,点击进入 添加机器人 选择自定义Webhook服务 点击添加 设置安全设置,详见说明文档 成功后,记录Webhook 2 API文档说明 点击设置说明 查看自…...

Linux中《基础IO》详细介绍

目录 理解"文件"狭义理解广义理解文件操作的归类认知系统角度文件类别 回顾C文件接口打开文件写文件读文件稍作修改,实现简单cat命令 输出信息到显示器,你有哪些方法stdin & stdout & stderr打开文件的方式 系统⽂件I/O⼀种传递标志位…...

sshd代码修改banner

sshd服务连接之后会收到字符串: SSH-2.0-OpenSSH_9.5 容易被hacker识别此服务为sshd服务。 是否可以通过修改此banner达到让人无法识别此服务的目的呢? 不能。因为这是写的SSH的协议中的。 也就是协议规定了banner必须这么写。 SSH- 开头&#xff0c…...

Vue3 PC端 UI组件库我更推荐Naive UI

一、Vue3生态现状与UI库选择的重要性 随着Vue3的稳定发布和Composition API的广泛采用,前端开发者面临着UI组件库的重新选择。一个好的UI库不仅能提升开发效率,还能确保项目的长期可维护性。本文将对比三大主流Vue3 UI库(Naive UI、Element …...

GAN模式奔溃的探讨论文综述(一)

简介 简介:今天带来一篇关于GAN的,对于模式奔溃的一个探讨的一个问题,帮助大家更好的解决训练中遇到的一个难题。 论文题目:An in-depth review and analysis of mode collapse in GAN 期刊:Machine Learning 链接:...