当前位置: 首页 > news >正文

《动手学深度学习 Pytorch版》 6.6 卷积神经网络

import torch
from torch import nn
from d2l import torch as d2l

6.6.1 LeNet

LetNet-5 由两个部分组成:

- 卷积编码器:由两个卷积核组成。
- 全连接层稠密块:由三个全连接层组成。

模型结构如下流程图(每个卷积块由一个卷积层、一个 sigmoid 激活函数和平均汇聚层组成):

全连接层(10)

↑ \uparrow

全连接层(84)

↑ \uparrow

全连接层(120)

↑ \uparrow

2 × 2 2\times2 2×2平均汇聚层,步幅2

↑ \uparrow

5 × 5 5\times5 5×5卷积层(16)

↑ \uparrow

2 × 2 2\times2 2×2平均汇聚层,步幅2

↑ \uparrow

5 × 5 5\times5 5×5卷积层(6),填充2

↑ \uparrow

输入图像( 28 × 28 28\times28 28×28 单通道)
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)  # 生成测试数据
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)  # 确保模型各层数据正确
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])

6.6.2 模型训练

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)  # 仍使用经典的 Fashion-MNIST 数据集
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).devicemetric = d2l.Accumulator(2)  # 生成一个有两个元素的列表,使用 add 将会累加到对应的元素上with torch.no_grad():for X, y in data_iter:# 为了使用 GPU,需要将数据移动到 GPU 上if isinstance(X, list):X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())  # 累加(正确预测的数量,总预测的数量)return metric[0] / metric[1]  # 正确率
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):  # 使用 Xavier 初始化权重if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)  # 移动数据到GPUoptimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')
lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.471, train acc 0.820, test acc 0.815
40056.7 examples/sec on cuda:0

在这里插入图片描述

练习

(1)将平均汇聚层替换为最大汇聚层,会发生什么?

net_Max = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))lr, num_epochs = 0.9, 10
train_ch6(net_Max, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.422, train acc 0.844, test acc 0.671
31151.6 examples/sec on cuda:0

在这里插入图片描述

几乎无区别


(2)尝试构建一个基于 LeNet 的更复杂网络,以提高其精准性。

a. 调节卷积窗口的大小。
b. 调整输出通道的数量。
c. 调整激活函数(如 ReLU)。
d. 调整卷积层的数量。
e. 调整全连接层的数量。
f. 调整学习率和其他训练细节(例如,初始化和轮数)。
net_Best = nn.Sequential(nn.Conv2d(1, 8, kernel_size=5, padding=2), nn.ReLU(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(32 * 3 * 3, 128), nn.ReLU(),nn.Linear(128, 64), nn.ReLU(),nn.Linear(64, 32), nn.ReLU(),nn.Linear(32, 10)
)
lr, num_epochs = 0.4, 10
train_ch6(net_Best, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.344, train acc 0.869, test acc 0.854
32868.3 examples/sec on cuda:0

在这里插入图片描述


(3)在 MNIST 数据集上尝试以上改进后的网络。

import torchvision
from torch.utils import data
from torchvision import transformstrans = transforms.ToTensor()
mnist_train = torchvision.datasets.MNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.MNIST(root="../data", train=False, transform=trans, download=True)
train_iter2 = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=d2l.get_dataloader_workers())
test_iter2 = data.DataLoader(mnist_test, batch_size, shuffle=True,num_workers=d2l.get_dataloader_workers())lr, num_epochs = 0.4, 5  # 大约 6 轮往后直接就爆炸
train_ch6(net_Best, train_iter2, test_iter2, num_epochs, lr, d2l.try_gpu())
loss 0.049, train acc 0.985, test acc 0.986
26531.1 examples/sec on cuda:0

在这里插入图片描述


(4)显示不同输入(例如,毛衣和外套)时 LetNet 第一层和第二层的激活值。

for X, y in test_iter:breakx_first_Sigmoid_layer = net[0:2](X)[0:9, 1, :, :]
d2l.show_images(x_first_Sigmoid_layer.reshape(9, 28, 28).cpu().detach(), 1, 9)
x_second_Sigmoid_layer = net[0:5](X)[0:9, 1, :, :]
d2l.show_images(x_second_Sigmoid_layer.reshape(9, 10, 10).cpu().detach(), 1, 9)
d2l.plt.show()


在这里插入图片描述

在这里插入图片描述

相关文章:

《动手学深度学习 Pytorch版》 6.6 卷积神经网络

import torch from torch import nn from d2l import torch as d2l6.6.1 LeNet LetNet-5 由两个部分组成: - 卷积编码器:由两个卷积核组成。 - 全连接层稠密块:由三个全连接层组成。模型结构如下流程图(每个卷积块由一个卷积层、…...

【微信小程序】项目初始化

| var() CSS 函数可以插入一个自定义属性(有时也被称为“CSS 变量”)的值,用来代替非自定义 属性中值的任何部分。 1.初始化样式与颜色 view,text{box-sizing: border-box; } page{--themColor:#ad905c;--globalColor:#18191b;--focusColor…...

C#,《小白学程序》第二十六课:大数乘法(BigInteger Multiply)的Toom-Cook 3算法及源程序

凑数的&#xff0c;仅供参考。 1 文本格式 /// <summary> /// 《小白学程序》第二十六课&#xff1a;大数&#xff08;BigInteger&#xff09;的Toom-Cook 3乘法 /// Toom-Cook 3-Way Multiplication /// </summary> /// <param name"a"></par…...

destoon自定义一个archiver内容文档

在archiver目录建立以下代码&#xff1a; <?php define(DT_REWRITE, true); require ../common.inc.php; $EXT[archiver_enable] or dheader(DT_PATH); //$DT_BOT or dheader(DT_PATH); $N $M $T array(); $mid or $mid 5; $vmid $list 0; foreach($MODULE as $k>…...

5-1 Dataset和DataLoader

Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道。 Dataset定义了数据集的内容&#xff0c;它相当于一个类似列表的数据结构&#xff0c;具有确定的长度&#xff0c;能够用索引获取数据集中的元素。 而DataLoader定义了按batch加载数据集的方法&#xff0c;它是…...

IDEA创建完Maven工程后,右下角一直显示正在下载Maven插件

原因&#xff1a; 这是由于新建的Maven工程&#xff0c;IDEA会用它内置的默认的Maven版本&#xff0c;使用国外的网站下载Maven所需的插件&#xff0c;速度很慢 。 解决方式&#xff1a; 每次创建 Project 后都需要设置 Maven 家目录位置&#xff08;就是我们自己下载的Mav…...

最新清理删除Mac电脑内存空间方法教程

Mac电脑使用的时间越久&#xff0c;系统的运行就会变的越卡顿&#xff0c;这是Mac os会出现的正常现象&#xff0c;卡顿的原因主要是系统缓存文件占用了较多的磁盘空间&#xff0c;或者Mac的内存空间已满。如果你的Mac运行速度变慢&#xff0c;很有可能是因为磁盘内存被过度占用…...

【调试经验】MySQL - fatal error: mysql/mysql.h: 没有那个文件或目录

机器环境&#xff1a; Ubuntu 22.04.3 LTS 报错问题 在编译一个项目时出现了一段SQL报错&#xff1a; CGImysql/sql_connection_pool.cpp:1:10: fatal error: mysql/mysql.h: 没有那个文件或目录 1 | #include <mysql/mysql.h> | ^~~~~~~~~~~~~~~ c…...

腾讯mini项目-【指标监控服务重构】2023-08-12

今日已办 Watermill Handler 将 4 个阶段的逻辑处理定义为 Handler 测试发现&#xff0c;添加的 handler 会被覆盖掉&#xff0c;故考虑添加为 middleware 且 4 个阶段的处理逻辑针对不同 topic 是相同的。 参考https://watermill.io/docs/messages-router/实现不同topic&am…...

kubeadm部署k8sv1.24使用cri-docker做为CRI

目的 测试使用cri-docker做为containerd和docker的中间层垫片。 规划 IP系统主机名10.0.6.5ubuntu 22.04.3 jammymaster01.kktb.org10.0.6.6ubuntu 22.04.3 jammymaster02.kktb.org10.0.6.7ubuntu 22.04.3 jammymaster03.kktb.org 配置 步骤&#xff1a; 系统优化 禁用sw…...

在c#中使用CancellationToken取消任务

目录 &#x1f680;介绍&#xff1a; &#x1f424;简单举例 &#x1f680;IsCancellationRequested &#x1f680;ThrowIfCancellationRequested &#x1f424;在控制器中使用 &#x1f680;通过异步方法的参数使用cancellationToken &#x1f680;api结合ThrowIfCancel…...

【项目经验】:elementui多选表格默认选中

一.需求 在页面刚打开就默认选中指定项。 二.方法Table Methods toggleRowSelection用于多选表格&#xff0c;切换某一行的选中状态&#xff0c;如果使用了第二个参数&#xff0c;则是设置这一行选中与否&#xff08;selected 为 true 则选中&#xff09;row, selected 详细…...

外星人入侵游戏-(创新版)

&#x1f308;write in front&#x1f308; &#x1f9f8;大家好&#xff0c;我是Aileen&#x1f9f8;.希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流. &#x1f194;本文由Aileen_0v0&#x1f9f8; 原创 CSDN首发&#x1f412; 如…...

HTML 学习笔记(基础)

它是超文本标记语言&#xff0c;由一大堆约定俗成的标签组成&#xff0c;而其标签里一般又有一些属性值可以设置。 W3C标准&#xff1a;网页主要三大部分 结构&#xff1a;HTML表现&#xff1a;CSS行为&#xff1a;JavaScript <!DOCTYPE html> <html lang"zh-…...

最小二乘法

Least Square Method 1、相关的矩阵公式2、线性回归3、最小二乘法3.1、损失函数&#xff08;Loss Function&#xff09;3.2、多维空间的损失函数3.3、解析法求解3.4、梯度下降法求解 1、相关的矩阵公式 P r e c o n d i t i o n : ξ ∈ R n , A ∈ R n ∗ n i : σ A ξ σ ξ…...

使用stelnet进行安全的远程管理

1. telnet有哪些不足&#xff1f; 2.ssh如何保证数据传输安全&#xff1f; 需求&#xff1a;远程telnet管理设备 用户定义需要在AAA模式下&#xff1a; 开启远程登录的服务&#xff1a;定义vty接口 然后从R2登录&#xff1a;是可以登录的 同理R3登录&#xff1a; 在R1也可以查…...

python 二手车数据分析以及价格预测

二手车交易信息爬取、数据分析以及交易价格预测 引言一、数据爬取1.1 解析数据1.2 编写代码爬1.2.1 获取详细信息1.2.2 数据处理 二、数据分析2.1 统计分析2.2 可视化分析 三、价格预测3.1 价格趋势分析(特征分析)3.2 价格预测 引言 本文着眼于车辆信息&#xff0c;结合当下较…...

JAVA医药进销存管理系统(附源码+调试)

JAVA医药进销存管理系统 功能描述 &#xff08;1&#xff09;登录模块&#xff1a;登录信息等存储在数据库中 &#xff08;2&#xff09;基本信息模块&#xff1a;分为药品信息模块、客户情况模块、供应商情况模块&#xff1b; &#xff08;3&#xff09;业务管理模块&#x…...

H5 <blockquote> 标签

主要应用于&#xff1a;内容引用 标签定义及使用说明 <blockquote> 标签定义摘自另一个源的块引用。 浏览器通常会对 <blockquote> 元素进行缩进。 提示和注释 提示&#xff1a;如果标记是不需要段落分隔的短引用&#xff0c;请使用 <q>。 HTML 4.01 与 H…...

nginx配置指南

nginx.conf配置 找到Nginx的安装目录下的nginx.conf文件&#xff0c;该文件负责Nginx的基础功能配置。 配置文件概述 Nginx的主配置文件(conf/nginx.conf)按以下结构组织&#xff1a; 配置块功能描述全局块与Nginx运行相关的全局设置events块与网络连接有关的设置http块代理…...

【网络】每天掌握一个Linux命令 - iftop

在Linux系统中&#xff0c;iftop是网络管理的得力助手&#xff0c;能实时监控网络流量、连接情况等&#xff0c;帮助排查网络异常。接下来从多方面详细介绍它。 目录 【网络】每天掌握一个Linux命令 - iftop工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景…...

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする 1、前言(1)情况说明(2)工程师的信仰2、知识点(1) にする1,接续:名词+にする2,接续:疑问词+にする3,(A)は(B)にする。(2)復習:(1)复习句子(2)ために & ように(3)そう(4)にする3、…...

多场景 OkHttpClient 管理器 - Android 网络通信解决方案

下面是一个完整的 Android 实现&#xff0c;展示如何创建和管理多个 OkHttpClient 实例&#xff0c;分别用于长连接、普通 HTTP 请求和文件下载场景。 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas…...

ServerTrust 并非唯一

NSURLAuthenticationMethodServerTrust 只是 authenticationMethod 的冰山一角 要理解 NSURLAuthenticationMethodServerTrust, 首先要明白它只是 authenticationMethod 的选项之一, 并非唯一 1 先厘清概念 点说明authenticationMethodURLAuthenticationChallenge.protectionS…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成

厌倦手动写WordPress文章&#xff1f;AI自动生成&#xff0c;效率提升10倍&#xff01; 支持多语言、自动配图、定时发布&#xff0c;让内容创作更轻松&#xff01; AI内容生成 → 不想每天写文章&#xff1f;AI一键生成高质量内容&#xff01;多语言支持 → 跨境电商必备&am…...

C++.OpenGL (10/64)基础光照(Basic Lighting)

基础光照(Basic Lighting) 冯氏光照模型(Phong Lighting Model) #mermaid-svg-GLdskXwWINxNGHso {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GLdskXwWINxNGHso .error-icon{fill:#552222;}#mermaid-svg-GLd…...

IT供电系统绝缘监测及故障定位解决方案

随着新能源的快速发展&#xff0c;光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域&#xff0c;IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选&#xff0c;但在长期运行中&#xff0c;例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...

C#中的CLR属性、依赖属性与附加属性

CLR属性的主要特征 封装性&#xff1a; 隐藏字段的实现细节 提供对字段的受控访问 访问控制&#xff1a; 可单独设置get/set访问器的可见性 可创建只读或只写属性 计算属性&#xff1a; 可以在getter中执行计算逻辑 不需要直接对应一个字段 验证逻辑&#xff1a; 可以…...

JavaScript 数据类型详解

JavaScript 数据类型详解 JavaScript 数据类型分为 原始类型&#xff08;Primitive&#xff09; 和 对象类型&#xff08;Object&#xff09; 两大类&#xff0c;共 8 种&#xff08;ES11&#xff09;&#xff1a; 一、原始类型&#xff08;7种&#xff09; 1. undefined 定…...

Python Einops库:深度学习中的张量操作革命

Einops&#xff08;爱因斯坦操作库&#xff09;就像给张量操作戴上了一副"语义眼镜"——让你用人类能理解的方式告诉计算机如何操作多维数组。这个基于爱因斯坦求和约定的库&#xff0c;用类似自然语言的表达式替代了晦涩的API调用&#xff0c;彻底改变了深度学习工程…...