ccc-pytorch-卷积神经网络实战(6)
文章目录
- 一、CIFAR10 与 lenet5
- 二、CIFAR10 与 ResNet
一、CIFAR10 与 lenet5
第一步:准备数据集
lenet5.py
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transformsdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)if __name__ =='__main__':main()
第二步:确认Lenet5网络流程结构
main.py
import torch
from torch import nn
from torch.nn import functional as Fclass Lenet5(nn.Module):def __init__(self):super(Lenet5, self).__init__()self.conv_unit = nn.Sequential(# x: [b, 3, 32, 32] => [b, 6, ]nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),nn.AvgPool2d(kernel_size=2, stride=2, padding=0),#nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),)self.fc_unit = nn.Sequential(nn.Linear(2,120), # 由输出结果反推(拉直打平)nn.ReLU(),nn.Linear(120,84),nn.ReLU(),nn.Linear(84,10))#[b,3,32,32]tmp = torch.randn(2, 3, 32, 32)out = self.conv_unit(tmp)#[2,16,5,5] 由输出结果得到print('conv out:', out.shape)def main():net = Lenet5()if __name__ == '__main__':main()
第三步:完善lenet5 结构并使用GPU加速
lenet5.py
import torch
from torch import nn
from torch.nn import functional as Fclass Lenet5(nn.Module):def __init__(self):super(Lenet5, self).__init__()self.conv_unit = nn.Sequential(# x: [b, 3, 32, 32] => [b, 6, ]nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),nn.AvgPool2d(kernel_size=2, stride=2, padding=0),#nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),)self.fc_unit = nn.Sequential(nn.Linear(16*5*5,120),nn.ReLU(),nn.Linear(120,84),nn.ReLU(),nn.Linear(84,10))#[b,3,32,32]tmp = torch.randn(2, 3, 32, 32)out = self.conv_unit(tmp)#[b,16,5,5]print('conv out:', out.shape)def forward(self,x):batchsz = x.size(0)# [b, 3, 32, 32] => [b, 16, 5, 5]x = self.conv_unit(x)#[b, 16, 5, 5] => [b,16*5*5]x = x.view(batchsz,16*5*5)# [b, 16*5*5] => [b, 10]logits = self.fc_unit(x)pred = F.softmax(logits,dim=1)return logitsdef main():net = Lenet5()tmp = torch.randn(2, 3, 32, 32)out = net(tmp)print('lenet out:', out.shape)if __name__ == '__main__':main()
main.py
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from lenet5 import Lenet5
from torch import nn, optimdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)device = torch.device('cuda')model = Lenet5().to(device)print(model)if __name__ =='__main__':main()
第四步:计算交叉熵和准确率,完成迭代
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from lenet5 import Lenet5
from torch import nn, optimdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)device = torch.device('cuda')model = Lenet5().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(),lr=1e-3)print(model)for epoch in range(1000):for batchidx, (x,label) in enumerate(cifar_train):# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)# logits: [b, 10]# label: [b]loss = criteon(logits,label)# backpropoptimizer.zero_grad()loss.backward()optimizer.step()print(epoch,'loss:',loss.item())model.eval()with torch.no_grad(): #之后代码不需backproptotal_correct = 0total_num = 0for x ,label in cifar_test:# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred,label).float().sum()total_num += x.size(0)acc = total_correct / total_numprint(epoch,acc)if __name__ =='__main__':main()
注意事项:
- 之所以在 测试时 添加 model.eval()是因为eval()时,BN会使用之前计算好的值,并且停止使用DropOut。保证用全部训练的均值和方差
二、CIFAR10 与 ResNet
第一步:构建ResNet18的网络结构
ResNet.py
import torch
from torch import nn
from torch.nn import functional as Fclass ResBlk(nn.Module):def __init__(self,ch_in,ch_out,stride=1):super(ResBlk,self).__init__()self.conv1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)self.bn2 = nn.BatchNorm2d(ch_out)self.extra = nn.Sequential()if ch_out != ch_in:self.extra = nn.Sequential(nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),nn.BatchNorm2d(ch_out))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))#[b, ch_in, h, w] = > [b, ch_out, h, w]out = self.extra(x) + outout = F.relu((out))return outclass ResNet18(nn.Module):def __init__(self):super(ResNet18, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),nn.BatchNorm2d(64))# followed 4 blocks# [b, 64, h, w] => [b, 128, h ,w]self.blk1 = ResBlk(64,128)# [b, 128, h, w] => [b, 256, h ,w]self.blk2 = ResBlk(128,256)# [b, 256, h, w] => [b, 512, h ,w]self.blk3 = ResBlk(256,512)# [b, 512, h, w] => [b, 1024, h ,w]self.blk4 = ResBlk(512,512)self.outlayer = nn.Linear(512*1*1,10)def forward(self,x):x = F.relu(self.conv1(x))x = self.blk1(x)x = self.blk2(x)x = self.blk3(x)x = self.blk4(x)print('after conv:', x.shape)# [b, 512, h, w] => [b, 512, 1, 1]x = F.adaptive_avg_pool2d(x, [1, 1])print('after pool:', x.shape)x = x.view(x.size(0), -1)x = self.outlayer(x)return xdef main():blk = ResBlk(64,128,stride=2)tmp = torch.randn(2,64,32,32)out = blk(tmp)print('block:',out.shape)x = torch.randn(2,3,32,32)model = ResNet18()out = model(x)print('resnet:',out.shape)if __name__ == '__main__':main()
第二步:代入第一个项目的main函数中即可
main.py
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from resnet import ResNet18
from torch import nn, optimdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)device = torch.device('cuda')model = ResNet18().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(),lr=1e-3)print(model)for epoch in range(1000):for batchidx, (x,label) in enumerate(cifar_train):# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)# logits: [b, 10]# label: [b]loss = criteon(logits,label)# backpropoptimizer.zero_grad()loss.backward()optimizer.step()print(epoch,'loss:',loss.item())model.eval()with torch.no_grad(): #之后代码不需backproptotal_correct = 0total_num = 0for x ,label in cifar_test:# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred,label).float().sum()total_num += x.size(0)acc = total_correct / total_numprint(epoch,acc)if __name__ =='__main__':main()
网络结构如下:
迭代准确率和交叉熵计算如下:
其他需要注意的地方:
- 并不是ResNet的paper中流程完全相同,但是十分类似
- 可以对数据进行数据增强和归一化等操作进一步提升效果
相关文章:

ccc-pytorch-卷积神经网络实战(6)
文章目录一、CIFAR10 与 lenet5二、CIFAR10 与 ResNet一、CIFAR10 与 lenet5 第一步:准备数据集 lenet5.py import torch from torch.utils.data import DataLoader from torchvision import datasets from torchvision import transformsdef main():batchsz 128C…...

置信椭圆(误差椭圆)详解
文章目录Part.I 预备知识Chap.I 一些概念Chap.II 主成分分析Chap.III Matlab 函数 randnChap.IV Matlab 函数 pcaPart.II 置信椭圆的含义Chap.I 一个 Matlab 实例Sec.I 两个不相关变量的特征Sec.II 两个相关变量的特征Chap.II 变换阵 (解相关矩阵) 的求解ReferencePart.I 预备知…...

FreeSWITCH 智能呼叫流程设计
文章目录1. 智能呼叫流程2. 细节处理1. 呼叫字符串指定拨号计划2. 外呼的拨号计划3. 语音打断的支持1. 智能呼叫流程 用户与机器人对话通常都是以文本的形式进行,但是借助 ASR 和 TTS 技术,以语音电话为载体的智能呼叫系统成为可能。智能呼叫系统涉及到…...
什么是Restful风格
什么是RestFul风格? Restful就是一个资源定位及资源操作的风格。不是标准也不是协议,只是一种风格。基于这个风格设计的软件可以更简洁,更有层次,更易于实现缓存等机制。 REST即Representational State Transfer的缩写࿰…...

sumifs的交叉 表的例子
比如这样,那么冰箱绿山店的栏位中,SUMIFS($D$3:$D$10,$B$3:$B$10,$F3,$C$3:$C$10,G$2)就是把求和范围,条件1设置为固定列的复合引用,条件2设置为固定行的复合引用即可。...

React :一、简单概念
目录 1.什么是React? 2.谁开发的 3.为什么要学React? 4.React的特点? 5.React依赖包 6.第一个React程序 7.虚拟DOM的两种创建方法 8.虚拟DOM和真实DOM 1.什么是React? 用于构建用户界面的JavaScript库,是一个将…...

Actipro WinForms Studio Crack
Actipro WinForms Studio Crack 已验证Microsoft.NET 7兼容性。 添加了MetroDark配色方案。 添加了支持MetroLight和MetroDark颜色方案的MetroScrollBarRenderer。 添加了IWindowsColorScheme接口,该接口将替换对WindowsColorScheme的大多数引用。 添加了IWindowsCo…...

英伦四地到底是什么关系?
英格兰、苏格兰、威尔士和北爱尔兰四地到底是什么关系,为何苏格兰非要独立?故事还要从中世纪说起。大不列颠岛位于欧洲西部,和欧洲大陆隔海相望。在古代,大不列颠岛和爱尔兰属于凯尔特人的领地。凯尔特人是欧洲西部一个庞大的族群…...

Google三大论文之GFS
Google三大论文之GFS Google GFS(Google File System) 文件系统,一个面向大规模数据密集型应用的、可伸缩的分布式文件系统。GFS 虽然运行在廉价的普遍硬件设备上,但是它依然了提供灾难冗余的能力,为大量客户机提供了…...

嵌入式安防监控项目——exynos4412主框架搭建
目录 一、模块化编程思维 二、安防监控项目主框架搭建 一、模块化编程思维 其实我们以前学习32使用keil的时候就是再用模块化的思维。每个硬件都单独有一个实现功能的C文件和声明函数,进行宏定义以及引用需要使用头文件的h文件。 比如简单的加减乘除取余操作我们…...

YOLOv5s网络模型讲解(一看就会)
文章目录前言1、YOLOv5s-6.0组成2、YOLOv5s网络介绍2.1、参数解析2.2、YOLOv5s.yaml2.3、YOLOv5s网络结构图3、附件3.1、yolov5s.yaml 解析表3.2、 yolov5l.yaml 解析表总结前言 最近在重构YOLOv5代码,本章主要介绍YOLOv5s的网络结构 1、YOLOv5s-6.0组成 我们熟知YO…...
kkfileView linux 离线安装
文章目录前言一、安装 LiberOffice二、安装kkfileView1.下载安装包2.启动总结前言 一、安装 LiberOffice 下载https://kkfileview.keking.cn/LibreOffice_7.1.4_Linux_x86-64_rpm.tar.gz 安装 tar -zxvf LibreOffice_7.1.4_Linux_x86-64_rpm.tar.gz cd LibreOffice_7.1.4.2_L…...

如何编写BI项目之ETL文档
XXXXBI项目之ETL文档 xxx项目组 ------------------------------------------------1---------------------------------------------------------------------- 目录 一 、ETL之概述 1、ETL是数据仓库建构/应用中的核心…...

【LeetCode】剑指 Offer 24. 反转链表 p142 -- Java Version
题目链接:https://leetcode.cn/problems/fan-zhuan-lian-biao-lcof/submissions/ 1. 题目介绍(24. 反转链表) 定义一个函数,输入一个链表的头节点,反转该链表并输出反转后链表的头节点。 【测试用例】: 示…...

LAY-EXCEL导出excel并实现单元格合并
通过lay-excel插件实现Excel导出,并实现单元格合并,样式设置等功能。更详细描述,请去lay-excel插件文档查看,地址:http://excel.wj2015.com/_book/docs/%E5%BF%AB%E9%80%9F%E4%B8%8A%E6%89%8B.html一、安装这里使用Vue…...

配置VM虚拟机Centos7网络
配置VM虚拟机Centos7网络 第一步,进入虚拟机设置选中【网络适配器】选择【NAT模式】 第二步,进入windows【控制面板\网络和 Internet\网络连接】设置网络状态。 我们选择【VMnet8】 点击【属性】查看它的网络配置 2 .我们找到【Internet 协议版本 4(TCP…...

Kafka 位移主题
Kafka 位移主题位移格式创建位移提交位移删除位移Kafka 的内部主题 (Internal Topic) : __consumer_offsets (位移主题,Offsets Topic) 老 Consumer 会将位移消息提交到 ZK 中保存 当 Consumer 重启后,能自动从 ZK 中读取位移数据,继续消费…...

详细讲解零拷贝机制的进化过程
一、传统拷贝方式(一)操作系统经过4次拷贝CPU 负责将数据从磁盘搬运到内核空间的 Page Cache 中;CPU 负责将数据从内核空间的 Page Cache 搬运到用户空间的缓冲区;CPU 负责将数据从用户空间的缓冲区搬运到内核空间的 Socket 缓冲区…...

2023年场外个股期权研究报告
第一章 概况 场外个股期权(Over-the-Counter Equity Option),是指由交易双方根据自己的需求和意愿,通过协商确定行权价格、行权日期等条款的股票期权。与交易所交易的标准化期权不同,场外个股期权的合同内容可以根据交…...
k8s pod,ns,pvc 强制删除
一、强制删除pod$ kubectl delete pod <your-pod-name> -n <name-space> --force --grace-period0解决方法:加参数 --force --grace-period0,grace-period表示过渡存活期,默认30s,在删除POD之前允许POD慢慢终止其上的…...
【杂谈】-递归进化:人工智能的自我改进与监管挑战
递归进化:人工智能的自我改进与监管挑战 文章目录 递归进化:人工智能的自我改进与监管挑战1、自我改进型人工智能的崛起2、人工智能如何挑战人类监管?3、确保人工智能受控的策略4、人类在人工智能发展中的角色5、平衡自主性与控制力6、总结与…...
【位运算】消失的两个数字(hard)
消失的两个数字(hard) 题⽬描述:解法(位运算):Java 算法代码:更简便代码 题⽬链接:⾯试题 17.19. 消失的两个数字 题⽬描述: 给定⼀个数组,包含从 1 到 N 所有…...

376. Wiggle Subsequence
376. Wiggle Subsequence 代码 class Solution { public:int wiggleMaxLength(vector<int>& nums) {int n nums.size();int res 1;int prediff 0;int curdiff 0;for(int i 0;i < n-1;i){curdiff nums[i1] - nums[i];if( (prediff > 0 && curdif…...

DingDing机器人群消息推送
文章目录 1 新建机器人2 API文档说明3 代码编写 1 新建机器人 点击群设置 下滑到群管理的机器人,点击进入 添加机器人 选择自定义Webhook服务 点击添加 设置安全设置,详见说明文档 成功后,记录Webhook 2 API文档说明 点击设置说明 查看自…...

Unity UGUI Button事件流程
场景结构 测试代码 public class TestBtn : MonoBehaviour {void Start(){var btn GetComponent<Button>();btn.onClick.AddListener(OnClick);}private void OnClick(){Debug.Log("666");}}当添加事件时 // 实例化一个ButtonClickedEvent的事件 [Formerl…...

高考志愿填报管理系统---开发介绍
高考志愿填报管理系统是一款专为教育机构、学校和教师设计的学生信息管理和志愿填报辅助平台。系统基于Django框架开发,采用现代化的Web技术,为教育工作者提供高效、安全、便捷的学生管理解决方案。 ## 📋 系统概述 ### 🎯 系统定…...
命令行关闭Windows防火墙
命令行关闭Windows防火墙 引言一、防火墙:被低估的"智能安检员"二、优先尝试!90%问题无需关闭防火墙方案1:程序白名单(解决软件误拦截)方案2:开放特定端口(解决网游/开发端口不通)三、命令行极速关闭方案方法一:PowerShell(推荐Win10/11)方法二:CMD命令…...
Yii2项目自动向GitLab上报Bug
Yii2 项目自动上报Bug 原理 yii2在程序报错时, 会执行指定action, 通过重写ErrorAction, 实现Bug自动提交至GitLab的issue 步骤 配置SiteController中的actions方法 public function actions(){return [error > [class > app\helpers\web\ErrorAction,],];}重写Error…...

Spring AI中使用ChatMemory实现会话记忆功能
文章目录 1、需求2、ChatMemory中消息的存储位置3、实现步骤1、引入依赖2、配置Spring AI3、配置chatmemory4、java层传递conversaionId 4、验证5、完整代码6、参考文档 1、需求 我们知道大型语言模型 (LLM) 是无状态的,这就意味着他们不会保…...

2025-06-08-深度学习网络介绍(语义分割,实例分割,目标检测)
深度学习网络介绍(语义分割,实例分割,目标检测) 前言 在开始这篇文章之前,我们得首先弄明白,什么是图像分割? 我们知道一个图像只不过是许多像素的集合。图像分割分类是对图像中属于特定类别的像素进行分类的过程,即像素级别的…...