当前位置: 首页 > article >正文

动手学深度学习:CNN和LeNet

前言

该篇文章记述从零如何实现CNN,以及LeNet对于之前数据集分类的提升效果。

从零实现卷积核

import torch
def conv2d(X,k):h,w=k.shapeY=torch.zeros((X.shape[0]-h+1,X.shape[1]-w+1))for i in range(Y.shape[0]):for j in range(Y.shape[1]):Y[i,j]=(X[i:i+h,j:j+w]*k).sum()return Y
X=torch.tensor([[0.,1.,2.],[3.,4.,5.],[6.,7.,8.]])
k=torch.tensor([[0.,1.],[2.,3.]])
conv2d(X,k)

在这里插入图片描述

卷积层

from torch import nn
class Conv2D(nn.Module):def __init__(self,kernel_size):super.__init__()self.weight=nn.Parameter(torch.rand(kernel_size))self.bias=nn.Parameter(torch.zeros(1))def forward(self,x):return conv2d(x,self.weight)+self.bias

验证卷积层对于图像的检测作用

x=torch.ones((6,8))
x[:,2:6]=0
x

在这里插入图片描述

k=torch.tensor([[1.0,-1.0]])
y=conv2d(x,k)
y

在这里插入图片描述
很明显这个卷积核提取到了垂直上的特征

conv2d(x.t(),k)

在这里插入图片描述
并没有学习到水平特征

学习卷积核

我们可以让卷积核自己学习里面的参数以达到对不同图像提取的作用

conv2d=nn.Conv2d(1,1,kernel_size=(1,2),bias=False)x=x.reshape((1,1,x.shape[0],x.shape[1]))
y=y.reshape((1,1,6,7))
lr=3e-2for i in range(10):y_hat=conv2d(x)l=(y_hat-y)**2conv2d.zero_grad()l.sum().backward()conv2d.weight.data[:]-=lr*conv2d.weight.gradprint(f"第{i}轮,loss为{l.sum()}")

在这里插入图片描述

conv2d.weight.data

在这里插入图片描述

填充

def comp_conv2d(conv2d,x):#(1,1)添加batch大小和通道数x=x.reshape((1,1)+x.shape)y=conv2d(x)return y.reshape(y.shape[2:])
conv2d=nn.Conv2d(1,1,kernel_size=3,padding=1)
x=torch.rand(size=(8,8))
comp_conv2d(conv2d,x).shape

在这里插入图片描述

conv2d=nn.Conv2d(1,1,kernel_size=(5,3),padding=(2,1))
x=torch.rand(size=(8,8))
comp_conv2d(conv2d,x).shape

在这里插入图片描述

步幅

conv2d=nn.Conv2d(1,1,kernel_size=(3,3),padding=1,stride=2)
x=torch.rand(size=(8,8))
comp_conv2d(conv2d,x).shape

在这里插入图片描述

多通道

from d2l import torch as d2l
def corr2d_multi_in(X,K):return sum(d2l.corr2d(x,k) for x,k in zip(X,K))
x=torch.randn(size=(4,2,3))
k=torch.randn(size=(4,1,3))
corr2d_multi_in(x,k)

在这里插入图片描述

多输出通道

def corr2d_multi_in_out(X,K):return torch.stack([corr2d_multi_in(X,k)for k in K],0)
K=torch.stack((k,k+1,k+2),0)
K.shape

在这里插入图片描述

corr2d_multi_in_out(x,K)

在这里插入图片描述

1x1卷积

def corr2d_multi_in_out_1x1(X,K):c_i,h,w=X.shapec_o=K.shape[0]X=X.reshape((c_i,h*w))K=K.reshape((c_o,c_i))Y=torch.matmul(K,X)return Y.reshape((c_o,h,w))
X=torch.normal(0,1,(3,3,3))
K=torch.normal(0,1,(2,3,1,1))
Y1=corr2d_multi_in_out_1x1(X,K)
Y2=corr2d_multi_in_out(X,K)
Y1==Y2

在这里插入图片描述

汇聚层

def pool2d(x,pool_size,mode='max'):p_h,p_w=pool_sizeY=torch.zeros((X.shape[0]-p_h+1,X.shape[1]-p_w+1))for i in range(Y.shape[0]):for j in range(Y.shape[1]):if mode=='max':Y[i,j]=X[i:i+p_h,j:j+p_w].max()elif mode=='avg':Y[i,j]=X[i:i+p_h,j:j+p_w].mean()return Y
X=torch.tensor([[0.0,1.,2.],[3.,4.,5.],[6.,7.,8.]])
pool2d(X,(2,2))

在这里插入图片描述

pool2d(X,(2,2),'avg')

在这里插入图片描述

LeNet

这是最早的神经网络,根据我的测试,这个模型在我的数据集上的效果比MLP要提高了1%以上,在这段时间里面,我页发现了原有数据集在分类上存在问题,所以重新制作了一份,在这份数据集上,随着我数据量的提升以及模型的修改,准确率达到了99.7%,且无过拟合现象。

原始的LeNet

from torch import nn
net=nn.Sequential(
nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2,stride=2),
nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2,stride=2),
nn.Flatten(),
nn.Linear(16*3*5,120),nn.Sigmoid(),
nn.Linear(120,84),nn.Sigmoid(),
nn.Linear(84,9))
def init_weight(m):if type(m)==nn.Linear or type(m)==nn.Conv2d:nn.init.xavier_uniform_(m.weight)
net.apply(init_weight)

在这里插入图片描述

测试结果

我忘记截图了,效果达到了99%以上,同样的数据集在MLP上是98%

改进后的LeNet

第一版

我将平均池化层改成了最大池化层

net=nn.Sequential(
nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Flatten(),
nn.Linear(16*3*5,120),nn.Sigmoid(),
nn.Linear(120,84),nn.Sigmoid(),
nn.Linear(84,9))
def init_weight(m):if type(m)==nn.Linear or type(m)==nn.Conv2d:nn.init.xavier_uniform_(m.weight)
net.apply(init_weight)

在这里插入图片描述

训练修改

我在训练过程中添加了记录test的loss最低时,保存pt和onnx,用于后续推理。

epochs_num=100
train_len=len(train_iter.dataset)
all_acc=[]
all_loss=[]
test_all_acc=[]
shape=None
for epoch in range(epochs_num):acc=0loss=0for x,y in train_iter:hat_y=net(x)l=loss_fn(hat_y,y)loss+=loptimer.zero_grad()l.backward()optimer.step()acc+=(hat_y.argmax(1)==y).sum()all_acc.append(acc/train_len)all_loss.append(loss.detach().numpy())test_acc=0test_loss=0test_len=len(test_iter.dataset)with torch.no_grad():for x,y in test_iter:shape=x.shapehat_y=net(x)test_loss+=loss_fn(hat_y,y)test_acc+=(hat_y.argmax(1)==y).sum()test_all_acc.append(test_acc/test_len)print(f'{epoch}的test的acc{test_acc/test_len}')# 保存测试损失最小的模型if test_loss < best_test_loss:best_test_loss = test_losstorch.save(net, best_model_path)dummy_input = torch.randn(shape)  torch.onnx.export(net, dummy_input, "./models/LeNet5.onnx", opset_version=11)print(f'Saved better model with Test Loss: {best_test_loss:.4f}')

在这里插入图片描述

损失函数可视化

plt.plot(range(1,epochs_num+1),all_loss,'.-',label='train_loss')
plt.text(epochs_num, all_loss[-1], f'{all_loss[-1]:.4f}', fontsize=12, verticalalignment='bottom')

在这里插入图片描述

准确率可视化

plt.plot(range(1,epochs_num+1),all_acc,'-',label='train_acc')
plt.text(epochs_num, all_acc[-1], f'{all_acc[-1]:.4f}', fontsize=12, verticalalignment='bottom')
plt.plot(range(1,epochs_num+1),test_all_acc,'-.',label='test_acc')
plt.legend()

在这里插入图片描述

预测结果

import numpy as np
with torch.no_grad():all_num=5index=1plt.figure(figsize=(12,5))for i,label in zip(test_data_path,test_labels):if index<=all_num:img=cv2.imread(i)input_img=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)img=cv2.cvtColor(input_img,cv2.COLOR_BGR2RGB)input_img = np.expand_dims(input_img, axis=2)  # 增加通道维度,形状变为 [1, H, W]input_img=transforms.ToTensor()(input_img)input_img = input_img.unsqueeze(0)  # 增加批量维度,形状变为 [1, 1, 28, 20]print(input_img.shape)result=net(input_img).argmax(1)plt.subplot(1,all_num,index)plt.imshow(img)plt.title(f'true{label},predict{result.detach().numpy()}')plt.axis("off")index+=1

在这里插入图片描述

第二版

我将sigmoid激活函数换成了ReLU函数,发现最终的收敛速度极快,损失值也下降了一些,test的acc上升了0.04%,虽然不多,但是训练时间极大减少。

net=nn.Sequential(
nn.Conv2d(1,6,kernel_size=5,padding=2),nn.ReLU(),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Conv2d(6,16,kernel_size=5),nn.ReLU(),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Flatten(),
nn.Linear(16*3*5,120),nn.ReLU(),
nn.Linear(120,84),nn.ReLU(),
nn.Linear(84,9))
def init_weight(m):if type(m)==nn.Linear or type(m)==nn.Conv2d:nn.init.xavier_uniform_(m.weight)
net.apply(init_weight)

其他部分同之前一样。

训练过程

在这里插入图片描述

损失函数可视化

在这里插入图片描述

准确率可视化

在这里插入图片描述

总结

数据集收集过程中遇到了部分麻烦,数据集还不够完整。

相关文章:

动手学深度学习:CNN和LeNet

前言 该篇文章记述从零如何实现CNN&#xff0c;以及LeNet对于之前数据集分类的提升效果。 从零实现卷积核 import torch def conv2d(X,k):h,wk.shapeYtorch.zeros((X.shape[0]-h1,X.shape[1]-w1))for i in range(Y.shape[0]):for j in range(Y.shape[1]):Y[i,j](X[i:ih,j:jw…...

C语言 常用系统函数

<string.h> 头文件中的字符串函数 标准库的头文件 <string.h> 中&#xff0c;有三个常用的字符串函数&#xff1a; 函数名 描述 strlen(str) 返回str的长度&#xff0c;类型是 size_t strcpy(str1,str2) 将str2中的字符串复制到str1中 strcat(str1,str2) 将…...

删除排序链表中的重复元素(js实现,LeetCode:83)

看到这道题的第一反应是使用快慢指针&#xff0c;之前做过类似的题&#xff1a;删除有序数组中的重复项&#xff08;js实现&#xff0c;LeetCode&#xff1a;26&#xff09;原理都是一样,区别是这题需要将重复项删除&#xff0c;所以只需要走一遍单循环就可以实现 /*** Defini…...

【嵌入式学习】如何利用gitee管理记录学习内容

# 新建git仓库并连接到本地 ## 查看本地是否下载git git --version ## 全局配置git git config --global user.name "你的用户名" git config --global user.email "你的邮箱" git config --global credential.helper store ## 初始化本地仓库 git…...

C#入门学习记录(四)C#运算符详解:掌握算术与条件运算符的必备技巧+字符串拼接

一、运算符概述 运算符是程序进行数学运算、逻辑判断的核心工具&#xff0c;C#中的运算符分为&#xff1a; 算术运算符 → 数学计算&#xff08; - * / %&#xff09; 条件运算符 → 三目判断&#xff08;?:&#xff09; 关系运算符 → 比较大小&#xff08;> < &#…...

单片机自学总结

自从工作以来&#xff0c;一直努力耕耘单片机&#xff0c;至今&#xff0c;颇有收获。从51单片机&#xff0c;PIC单片机&#xff0c;直到STM32&#xff0c;以及RTOS和Linux&#xff0c;几乎天天在搞:51单片机&#xff0c;STM8S207单片机&#xff0c;PY32F003单片机&#xff0c;…...

Unity教程(二十二)技能系统 分身技能

Unity开发2D类银河恶魔城游戏学习笔记 Unity教程&#xff08;零&#xff09;Unity和VS的使用相关内容 Unity教程&#xff08;一&#xff09;开始学习状态机 Unity教程&#xff08;二&#xff09;角色移动的实现 Unity教程&#xff08;三&#xff09;角色跳跃的实现 Unity教程&…...

HTML5扫雷游戏开发实战

HTML5扫雷游戏开发实战 这里写目录标题 HTML5扫雷游戏开发实战项目介绍技术栈项目架构1. 游戏界面设计2. 核心类设计 核心功能实现1. 游戏初始化2. 地雷布置算法3. 数字计算逻辑4. 扫雷功能实现 性能优化1. DOM操作优化2. 算法优化 项目亮点技术难点突破1. 首次点击保护2. 连锁…...

【Git学习笔记】Git分支管理策略及其结构原理分析

【Git学习笔记】Git分支管理策略及其结构原理分析 &#x1f525;个人主页&#xff1a;大白的编程日记 &#x1f525;专栏&#xff1a;Git学习笔记 文章目录 【Git学习笔记】Git分支管理策略及其结构原理分析前言一.合并冲突二. 分支管理策略2.1 分支策略2.2 bug分支2.3 删除临…...

Spring Cloud Alibaba Nacos 2023.X 配置问题

文章目录 问题现象&#xff08;一&#xff09;解决方法&#xff08;一&#xff09;问题现象&#xff08;二&#xff09;解决方法&#xff08;二&#xff09;问题现象&#xff08;三&#xff09;解决方法&#xff08;三&#xff09; 问题现象&#xff08;一&#xff09; Spring…...

厨卫行业供应链产销协同前中后大平台现状需求分析报告+P120(120页PPT)(文末有下载方式)

资料解读&#xff1a;厨卫行业供应链产销协同前中后大平台现状需求分析报告 详细资料请看本解读文章的最后内容。在当前厨卫行业竞争激烈的市场环境下&#xff0c;企业的发展战略和业务模式创新至关重要。本次解读的报告围绕某厨卫企业展开&#xff0c;深入探讨其供应链产销协同…...

我在哪,要去哪

在直播间听到一首好听的歌《我在哪&#xff0c;要去哪》-汤倩。 遇见的事&#xff1a;21~24号抽调去招生。 感受到的情绪&#xff1a;公假吗&#xff1f;给工作量吗&#xff1f;月工作量不够扣钱吗&#xff1f;报销方便吗&#xff1f;有事情&#xff0c;从来不解决后顾&#x…...

SpringBoot-2整合MyBatis以及基本的使用方法

目录 1.引入依赖 2.数据库表的创建 3.数据源的配置 4.编写pojo类 5.编写controller类 6.编写接口 7.编写接口的实现类 8.编写mapper 1.引入依赖 在pom.xml引入依赖 <!-- mysql--><dependency><groupId>com.mysql</groupId><artifac…...

本周安全速报(2025.3.11~3.17)

合规速递 01 瑞士出台新规&#xff1a;关基设施遭遇网络攻击需在24小时内上报 原文: https://www.bleepingcomputer.com/news/security/swiss-critical-sector-faces-new-24-hour-cyberattack-reporting-rule/ 新规要求&#xff0c;关键基础设施组织发现网络攻击后&…...

【css酷炫效果】纯CSS实现瀑布流加载动画

【css酷炫效果】纯CSS实现瀑布流加载动画 缘创作背景html结构css样式完整代码基础版进阶版(无限往复加载) 效果图 想直接拿走的老板&#xff0c;链接放在这里&#xff1a;https://download.csdn.net/download/u011561335/90492012 缘 创作随缘&#xff0c;不定时更新。 创作…...

咖啡点单小程序毕业设计(JAVA+SpringBoot+微信小程序+完整源码+论文)

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取项目下载方式&#x1f345; 一、项目背景介绍&#xff1a; 随着社会的快速发展和…...

网络编程套接字【端口号/TCPUDP/网络字节序/socket编程接口/UDPTCP网络实验】

网络编程套接字 0. 前言1. 认识端口号2. 认识TCP和UDP协议3. 网络字节序4. socket编程接口5. 实现一个简单的UDP网络程序5.1 需求分析5.2 头文件准备5.3 服务器端设计5.4 客户端设计5.5 本地测试5.6 跨网络测试5.7 UDP小应用——客户端输入命令&#xff0c;服务器端执行 6. 地址…...

【c++】内存序 和 内存一致性模型

c 11 中为了支持并发&#xff0c;定义了内存序和内存一致性模型。这个概念听起来非常高深&#xff0c;好像是在多线程编程领域浸淫多年之后的神级程序员才能搞明白&#xff0c;并用明白的东西。 本文尝试用最简单的方式说清楚这个概念。因为这个概念真的超级简单&#xff0c;大…...

7-字符串

1-ASCII 0-9 对应 48-57 A-Z 对应 65-90 a-z 对应 97-122 2-字符数组 字符变量存储单个字符 字符数组存储多个字符 字符串就是字符数组加上结束符 ’ \0 ’ #include <iostream> using namespace std; int main(){//是字符数组&#xff0c;不是字符串char a1[]{C,,};…...

DeepSeek 3FS 与 JuiceFS:架构与特性比较

近期&#xff0c;DeepSeek 开源了其文件系统 Fire-Flyer File System (3FS)&#xff0c;使得文件系统这一有着 70 多年历时的“古老”的技术&#xff0c;又获得了各方的关注。在 AI 业务中&#xff0c;企业需要处理大量的文本、图像、视频等非结构化数据&#xff0c;还需要应对…...

数据结构 -- 二叉树的存储结构

二叉树的存储结构 顺序存储 #define MaxSize 100 struct TreeNode{ElemType value; //结点中的数据元素bool isEmpty; //结点元素是否为空 };//定义一个长度为MaxSize的数组t&#xff0c;按照从上至下、从左至右的顺序依次完成存储完全二叉树中的各个节点 TreeNode t[MaxSi…...

Unity WebGL项目访问时自动全屏

Unity WebGL项目访问时自动全屏 打开TemplateData/style.css文件 在文件最下方添加红色框内的两行代码 使用vscode或者其他编辑器打开index.html 将按钮注释掉&#xff0c;并且更改为默认全屏...

vue computed 计算属性简述

Vue 的 ‌计算属性&#xff08;Computed Properties&#xff09;‌ 是 Vue 实例中一种特殊的属性&#xff0c;用于‌声明式地定义依赖其他数据动态计算得出的值‌。它的核心优势在于能够自动追踪依赖关系&#xff0c;并缓存计算结果&#xff0c;避免重复计算&#xff0c;提升性…...

破局者登场:中国首款AI原生IDE Trae深度解析--开启人机协同编程新纪元

摘要 字节跳动于2025年3月3日正式发布中国首款AI原生集成开发环境Trae国内版&#xff0c;以动态协作、全场景AI赋能及本土化适配为核心优势。Trae内置Doubao-1.5-pro与DeepSeek R1/V3双引擎&#xff0c;支持基于自然语言生成端到端代码框架、实时上下文感知与智能Bug修复&…...

如何通过Python的`requests`库接入DeepSeek智能API

本文将详细介绍如何通过Python的requests库接入DeepSeek智能API&#xff0c;实现数据交互与智能对话功能。文章涵盖环境配置、API调用、参数解析、错误处理等全流程内容&#xff0c;并提供完整代码示例。 一、环境准备与API密钥获取 1. 注册DeepSeek账号 访问DeepSeek官网&am…...

【C++】std::make_shared 详解

std::make_shared 详解 1. std::make_shared 简介 std::make_shared 是 C11 标准引入的一个函数模板&#xff0c;用于创建 std::shared_ptr 对象&#xff0c;并高效地分配和管理对象的内存。它比直接使用 std::shared_ptr 构造函数 std::shared_ptr<T>(new T(...)) 具有…...

【NoSql】Redis

Ubuntu22.04版本编译安装 Redis Redis version7.4.2 #解压源码包 tar -zxvf redis-stable.tar.gz cd redis-stable/ make make install安装好了后&#xff0c;可执行文件默认会放入/usr/local/bin/ rootluobozi:~ ls /usr/local/bin/* /usr/local/bin/redis-cli /usr/local/…...

ClickHouse Docker 容器迁移指南:从测试环境到离线正式环境

ClickHouse Docker 容器迁移指南&#xff1a;从测试环境到离线正式环境 在实际开发和运维过程中&#xff0c;我们经常需要将测试环境中的服务迁移到正式环境&#xff0c;尤其是当正式环境处于离线状态时&#xff0c;这种迁移会变得更加复杂。本文将详细介绍如何将运行在 Docke…...

C# WPF编程-Menu

C# WPF编程-Menu 布局&#xff1a;代码&#xff1a;效果 在WPF&#xff08;Windows Presentation Foundation&#xff09;中&#xff0c;Menu控件用于创建下拉菜单或上下文菜单&#xff0c;它提供了丰富的定制选项来满足不同的应用需求。下面将介绍如何在WPF应用程序中使用Menu…...

利用Python爬虫获取Shopee(虾皮)商品详情:实战指南

在跨境电商领域&#xff0c;Shopee&#xff08;虾皮&#xff09;作为东南亚及台湾地区领先的电商平台&#xff0c;拥有海量的商品信息。无论是进行市场调研、数据分析&#xff0c;还是寻找热门商品&#xff0c;获取Shopee商品详情都是一项极具价值的任务。然而&#xff0c;手动…...