当前位置: 首页 > news >正文

ccc-pytorch-卷积神经网络实战(6)

文章目录

      • 一、CIFAR10 与 lenet5
      • 二、CIFAR10 与 ResNet

一、CIFAR10 与 lenet5

image-20230305193723321
第一步:准备数据集
lenet5.py

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transformsdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)if __name__ =='__main__':main()

image-20230306214402277

第二步:确认Lenet5网络流程结构
main.py

import torch
from torch import nn
from torch.nn import functional as Fclass Lenet5(nn.Module):def __init__(self):super(Lenet5, self).__init__()self.conv_unit = nn.Sequential(# x: [b, 3, 32, 32] => [b, 6, ]nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),nn.AvgPool2d(kernel_size=2, stride=2, padding=0),#nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),)self.fc_unit = nn.Sequential(nn.Linear(2,120), # 由输出结果反推(拉直打平)nn.ReLU(),nn.Linear(120,84),nn.ReLU(),nn.Linear(84,10))#[b,3,32,32]tmp = torch.randn(2, 3, 32, 32)out = self.conv_unit(tmp)#[2,16,5,5]   由输出结果得到print('conv out:', out.shape)def main():net = Lenet5()if __name__ == '__main__':main()


第三步:完善lenet5 结构并使用GPU加速
lenet5.py

import torch
from torch import nn
from torch.nn import functional as Fclass Lenet5(nn.Module):def __init__(self):super(Lenet5, self).__init__()self.conv_unit = nn.Sequential(# x: [b, 3, 32, 32] => [b, 6, ]nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),nn.AvgPool2d(kernel_size=2, stride=2, padding=0),#nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),)self.fc_unit = nn.Sequential(nn.Linear(16*5*5,120),nn.ReLU(),nn.Linear(120,84),nn.ReLU(),nn.Linear(84,10))#[b,3,32,32]tmp = torch.randn(2, 3, 32, 32)out = self.conv_unit(tmp)#[b,16,5,5]print('conv out:', out.shape)def forward(self,x):batchsz = x.size(0)# [b, 3, 32, 32] => [b, 16, 5, 5]x = self.conv_unit(x)#[b, 16, 5, 5] => [b,16*5*5]x = x.view(batchsz,16*5*5)# [b, 16*5*5] => [b, 10]logits = self.fc_unit(x)pred = F.softmax(logits,dim=1)return logitsdef main():net = Lenet5()tmp = torch.randn(2, 3, 32, 32)out = net(tmp)print('lenet out:', out.shape)if __name__ == '__main__':main()

main.py

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from lenet5 import Lenet5
from    torch import nn, optimdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)device = torch.device('cuda')model = Lenet5().to(device)print(model)if __name__ =='__main__':main()

image-20230307210738450
第四步:计算交叉熵和准确率,完成迭代

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from lenet5 import Lenet5
from    torch import nn, optimdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)device = torch.device('cuda')model = Lenet5().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(),lr=1e-3)print(model)for epoch in range(1000):for batchidx, (x,label) in enumerate(cifar_train):# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)# logits: [b, 10]# label:  [b]loss = criteon(logits,label)# backpropoptimizer.zero_grad()loss.backward()optimizer.step()print(epoch,'loss:',loss.item())model.eval()with torch.no_grad(): #之后代码不需backproptotal_correct = 0total_num = 0for x ,label in cifar_test:# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred,label).float().sum()total_num += x.size(0)acc = total_correct / total_numprint(epoch,acc)if __name__ =='__main__':main()

image-20230307212056887
注意事项:

  • 之所以在 测试时 添加 model.eval()是因为eval()时,BN会使用之前计算好的值,并且停止使用DropOut。保证用全部训练的均值和方差

二、CIFAR10 与 ResNet

img
第一步:构建ResNet18的网络结构
ResNet.py

import torch
from torch import  nn
from torch.nn import functional as Fclass ResBlk(nn.Module):def __init__(self,ch_in,ch_out,stride=1):super(ResBlk,self).__init__()self.conv1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)self.bn2 = nn.BatchNorm2d(ch_out)self.extra = nn.Sequential()if ch_out != ch_in:self.extra = nn.Sequential(nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),nn.BatchNorm2d(ch_out))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))#[b, ch_in, h, w] = > [b, ch_out, h, w]out = self.extra(x) + outout = F.relu((out))return outclass ResNet18(nn.Module):def __init__(self):super(ResNet18, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),nn.BatchNorm2d(64))# followed 4 blocks# [b, 64, h, w] => [b, 128, h ,w]self.blk1 = ResBlk(64,128)# [b, 128, h, w] => [b, 256, h ,w]self.blk2 = ResBlk(128,256)# [b, 256, h, w] => [b, 512, h ,w]self.blk3 = ResBlk(256,512)# [b, 512, h, w] => [b, 1024, h ,w]self.blk4 = ResBlk(512,512)self.outlayer = nn.Linear(512*1*1,10)def forward(self,x):x = F.relu(self.conv1(x))x = self.blk1(x)x = self.blk2(x)x = self.blk3(x)x = self.blk4(x)print('after conv:', x.shape)# [b, 512, h, w] => [b, 512, 1, 1]x = F.adaptive_avg_pool2d(x, [1, 1])print('after pool:', x.shape)x = x.view(x.size(0), -1)x = self.outlayer(x)return xdef main():blk = ResBlk(64,128,stride=2)tmp = torch.randn(2,64,32,32)out = blk(tmp)print('block:',out.shape)x = torch.randn(2,3,32,32)model  = ResNet18()out = model(x)print('resnet:',out.shape)if __name__ == '__main__':main()

第二步:代入第一个项目的main函数中即可
main.py

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from resnet import ResNet18
from    torch import nn, optimdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)device = torch.device('cuda')model = ResNet18().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(),lr=1e-3)print(model)for epoch in range(1000):for batchidx, (x,label) in enumerate(cifar_train):# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)# logits: [b, 10]# label:  [b]loss = criteon(logits,label)# backpropoptimizer.zero_grad()loss.backward()optimizer.step()print(epoch,'loss:',loss.item())model.eval()with torch.no_grad(): #之后代码不需backproptotal_correct = 0total_num = 0for x ,label in cifar_test:# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred,label).float().sum()total_num += x.size(0)acc = total_correct / total_numprint(epoch,acc)if __name__ =='__main__':main()

网络结构如下:
image-20230308192349803
迭代准确率和交叉熵计算如下:
image-20230308193023625
其他需要注意的地方:

  • 并不是ResNet的paper中流程完全相同,但是十分类似
  • 可以对数据进行数据增强和归一化等操作进一步提升效果

相关文章:

ccc-pytorch-卷积神经网络实战(6)

文章目录一、CIFAR10 与 lenet5二、CIFAR10 与 ResNet一、CIFAR10 与 lenet5 第一步:准备数据集 lenet5.py import torch from torch.utils.data import DataLoader from torchvision import datasets from torchvision import transformsdef main():batchsz 128C…...

置信椭圆(误差椭圆)详解

文章目录Part.I 预备知识Chap.I 一些概念Chap.II 主成分分析Chap.III Matlab 函数 randnChap.IV Matlab 函数 pcaPart.II 置信椭圆的含义Chap.I 一个 Matlab 实例Sec.I 两个不相关变量的特征Sec.II 两个相关变量的特征Chap.II 变换阵 (解相关矩阵) 的求解ReferencePart.I 预备知…...

FreeSWITCH 智能呼叫流程设计

文章目录1. 智能呼叫流程2. 细节处理1. 呼叫字符串指定拨号计划2. 外呼的拨号计划3. 语音打断的支持1. 智能呼叫流程 用户与机器人对话通常都是以文本的形式进行,但是借助 ASR 和 TTS 技术,以语音电话为载体的智能呼叫系统成为可能。智能呼叫系统涉及到…...

什么是Restful风格

什么是RestFul风格? Restful就是一个资源定位及资源操作的风格。不是标准也不是协议,只是一种风格。基于这个风格设计的软件可以更简洁,更有层次,更易于实现缓存等机制。 REST即Representational State Transfer的缩写&#xff0…...

sumifs的交叉 表的例子

比如这样,那么冰箱绿山店的栏位中,SUMIFS($D$3:$D$10,$B$3:$B$10,$F3,$C$3:$C$10,G$2)就是把求和范围,条件1设置为固定列的复合引用,条件2设置为固定行的复合引用即可。...

React :一、简单概念

目录 1.什么是React? 2.谁开发的 3.为什么要学React? 4.React的特点? 5.React依赖包 6.第一个React程序 7.虚拟DOM的两种创建方法 8.虚拟DOM和真实DOM 1.什么是React? 用于构建用户界面的JavaScript库,是一个将…...

Actipro WinForms Studio Crack

Actipro WinForms Studio Crack 已验证Microsoft.NET 7兼容性。 添加了MetroDark配色方案。 添加了支持MetroLight和MetroDark颜色方案的MetroScrollBarRenderer。 添加了IWindowsColorScheme接口,该接口将替换对WindowsColorScheme的大多数引用。 添加了IWindowsCo…...

英伦四地到底是什么关系?

英格兰、苏格兰、威尔士和北爱尔兰四地到底是什么关系,为何苏格兰非要独立?故事还要从中世纪说起。大不列颠岛位于欧洲西部,和欧洲大陆隔海相望。在古代,大不列颠岛和爱尔兰属于凯尔特人的领地。凯尔特人是欧洲西部一个庞大的族群…...

Google三大论文之GFS

Google三大论文之GFS Google GFS(Google File System) 文件系统,一个面向大规模数据密集型应用的、可伸缩的分布式文件系统。GFS 虽然运行在廉价的普遍硬件设备上,但是它依然了提供灾难冗余的能力,为大量客户机提供了…...

嵌入式安防监控项目——exynos4412主框架搭建

目录 一、模块化编程思维 二、安防监控项目主框架搭建 一、模块化编程思维 其实我们以前学习32使用keil的时候就是再用模块化的思维。每个硬件都单独有一个实现功能的C文件和声明函数,进行宏定义以及引用需要使用头文件的h文件。 比如简单的加减乘除取余操作我们…...

YOLOv5s网络模型讲解(一看就会)

文章目录前言1、YOLOv5s-6.0组成2、YOLOv5s网络介绍2.1、参数解析2.2、YOLOv5s.yaml2.3、YOLOv5s网络结构图3、附件3.1、yolov5s.yaml 解析表3.2、 yolov5l.yaml 解析表总结前言 最近在重构YOLOv5代码,本章主要介绍YOLOv5s的网络结构 1、YOLOv5s-6.0组成 我们熟知YO…...

kkfileView linux 离线安装

文章目录前言一、安装 LiberOffice二、安装kkfileView1.下载安装包2.启动总结前言 一、安装 LiberOffice 下载https://kkfileview.keking.cn/LibreOffice_7.1.4_Linux_x86-64_rpm.tar.gz 安装 tar -zxvf LibreOffice_7.1.4_Linux_x86-64_rpm.tar.gz cd LibreOffice_7.1.4.2_L…...

如何编写BI项目之ETL文档

XXXXBI项目之ETL文档 xxx项目组 ------------------------------------------------1---------------------------------------------------------------------- 目录 一 、ETL之概述 1、ETL是数据仓库建构/应用中的核心…...

【LeetCode】剑指 Offer 24. 反转链表 p142 -- Java Version

题目链接:https://leetcode.cn/problems/fan-zhuan-lian-biao-lcof/submissions/ 1. 题目介绍(24. 反转链表) 定义一个函数,输入一个链表的头节点,反转该链表并输出反转后链表的头节点。 【测试用例】: 示…...

LAY-EXCEL导出excel并实现单元格合并

通过lay-excel插件实现Excel导出,并实现单元格合并,样式设置等功能。更详细描述,请去lay-excel插件文档查看,地址:http://excel.wj2015.com/_book/docs/%E5%BF%AB%E9%80%9F%E4%B8%8A%E6%89%8B.html一、安装这里使用Vue…...

配置VM虚拟机Centos7网络

配置VM虚拟机Centos7网络 第一步,进入虚拟机设置选中【网络适配器】选择【NAT模式】 第二步,进入windows【控制面板\网络和 Internet\网络连接】设置网络状态。 我们选择【VMnet8】 点击【属性】查看它的网络配置 2 .我们找到【Internet 协议版本 4(TCP…...

Kafka 位移主题

Kafka 位移主题位移格式创建位移提交位移删除位移Kafka 的内部主题 (Internal Topic) : __consumer_offsets (位移主题,Offsets Topic) 老 Consumer 会将位移消息提交到 ZK 中保存 当 Consumer 重启后,能自动从 ZK 中读取位移数据,继续消费…...

详细讲解零拷贝机制的进化过程

一、传统拷贝方式(一)操作系统经过4次拷贝CPU 负责将数据从磁盘搬运到内核空间的 Page Cache 中;CPU 负责将数据从内核空间的 Page Cache 搬运到用户空间的缓冲区;CPU 负责将数据从用户空间的缓冲区搬运到内核空间的 Socket 缓冲区…...

2023年场外个股期权研究报告

第一章 概况 场外个股期权(Over-the-Counter Equity Option),是指由交易双方根据自己的需求和意愿,通过协商确定行权价格、行权日期等条款的股票期权。与交易所交易的标准化期权不同,场外个股期权的合同内容可以根据交…...

k8s pod,ns,pvc 强制删除

一、强制删除pod$ kubectl delete pod <your-pod-name> -n <name-space> --force --grace-period0解决方法&#xff1a;加参数 --force --grace-period0&#xff0c;grace-period表示过渡存活期&#xff0c;默认30s&#xff0c;在删除POD之前允许POD慢慢终止其上的…...

uniapp 对接腾讯云IM群组成员管理(增删改查)

UniApp 实战&#xff1a;腾讯云IM群组成员管理&#xff08;增删改查&#xff09; 一、前言 在社交类App开发中&#xff0c;群组成员管理是核心功能之一。本文将基于UniApp框架&#xff0c;结合腾讯云IM SDK&#xff0c;详细讲解如何实现群组成员的增删改查全流程。 权限校验…...

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型

摘要 拍照搜题系统采用“三层管道&#xff08;多模态 OCR → 语义检索 → 答案渲染&#xff09;、两级检索&#xff08;倒排 BM25 向量 HNSW&#xff09;并以大语言模型兜底”的整体框架&#xff1a; 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后&#xff0c;分别用…...

业务系统对接大模型的基础方案:架构设计与关键步骤

业务系统对接大模型&#xff1a;架构设计与关键步骤 在当今数字化转型的浪潮中&#xff0c;大语言模型&#xff08;LLM&#xff09;已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中&#xff0c;不仅可以优化用户体验&#xff0c;还能为业务决策提供…...

Qt/C++开发监控GB28181系统/取流协议/同时支持udp/tcp被动/tcp主动

一、前言说明 在2011版本的gb28181协议中&#xff0c;拉取视频流只要求udp方式&#xff0c;从2016开始要求新增支持tcp被动和tcp主动两种方式&#xff0c;udp理论上会丢包的&#xff0c;所以实际使用过程可能会出现画面花屏的情况&#xff0c;而tcp肯定不丢包&#xff0c;起码…...

Appium+python自动化(十六)- ADB命令

简介 Android 调试桥(adb)是多种用途的工具&#xff0c;该工具可以帮助你你管理设备或模拟器 的状态。 adb ( Android Debug Bridge)是一个通用命令行工具&#xff0c;其允许您与模拟器实例或连接的 Android 设备进行通信。它可为各种设备操作提供便利&#xff0c;如安装和调试…...

阿里云ACP云计算备考笔记 (5)——弹性伸缩

目录 第一章 概述 第二章 弹性伸缩简介 1、弹性伸缩 2、垂直伸缩 3、优势 4、应用场景 ① 无规律的业务量波动 ② 有规律的业务量波动 ③ 无明显业务量波动 ④ 混合型业务 ⑤ 消息通知 ⑥ 生命周期挂钩 ⑦ 自定义方式 ⑧ 滚的升级 5、使用限制 第三章 主要定义 …...

相机Camera日志实例分析之二:相机Camx【专业模式开启直方图拍照】单帧流程日志详解

【关注我&#xff0c;后续持续新增专题博文&#xff0c;谢谢&#xff01;&#xff01;&#xff01;】 上一篇我们讲了&#xff1a; 这一篇我们开始讲&#xff1a; 目录 一、场景操作步骤 二、日志基础关键字分级如下 三、场景日志如下&#xff1a; 一、场景操作步骤 操作步…...

DAY 47

三、通道注意力 3.1 通道注意力的定义 # 新增&#xff1a;通道注意力模块&#xff08;SE模块&#xff09; class ChannelAttention(nn.Module):"""通道注意力模块(Squeeze-and-Excitation)"""def __init__(self, in_channels, reduction_rat…...

RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程

本文较长&#xff0c;建议点赞收藏&#xff0c;以免遗失。更多AI大模型应用开发学习视频及资料&#xff0c;尽在聚客AI学院。 本文全面剖析RNN核心原理&#xff0c;深入讲解梯度消失/爆炸问题&#xff0c;并通过LSTM/GRU结构实现解决方案&#xff0c;提供时间序列预测和文本生成…...

在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?

uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件&#xff0c;用于在原生应用中加载 HTML 页面&#xff1a; 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...