动手学深度学习—使用块的网络VGG(代码详解)
目录
- 1. VGG块
- 2. VGG网络
- 3. 训练模型
1. VGG块
经典卷积神经网络的基本组成部分是下面的这个序列:
1.带填充以保持分辨率的卷积层;
2.非线性激活函数,如ReLU;
3.汇聚层,如最大汇聚层。
定义网络块,便于我们重复构建某些网络架构,不仅利于代码编写与阅读也利于后面参数的优化
"""定义了一个名为vgg_block的函数来实现一个VGG块:1、卷积层的数量num_convs2、输入通道的数量in_channels 3、输出通道的数量out_channels
"""
import torch
from torch import nn
from d2l import torch as d2l# 定义vgg块,(卷积层数,输入通道,输出通道)
def vgg_block(num_convs, in_channels, out_channels):# 创建空网络结果,之后通过循环操作使用append函数进行添加layers = []# 循环操作,添加卷积层和非线性激活层for _ in range(num_convs):layers.append(nn.Conv2d(in_channels, out_channels,kernel_size=3, padding=1))layers.append(nn.ReLU())in_channels = out_channels# 最后添加最大值汇聚层layers.append(nn.MaxPool2d(kernel_size=2, stride=2))return nn.Sequential(*layers)
2. VGG网络

由于会重复用到卷积层、激活函数ReLU和汇聚层,我们将这三个组合成一个块,每次引用这个块来构建网络模型。
通过定义VGG块,使得重复的网络结构实现起来更加容易,也利于代码阅读。
# 原VGG网络有5个卷积块,前两个有一个卷积层,后三个块有两个卷积层
# 该网络使用8个卷积层和3个全连接层,因此它通常被称为VGG-11# (卷积层数,输出通道数)
conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
实现VGG-11:使用8个卷积层和3个全连接层
# 通过for循环实现VGG-11
def vgg(conv_arch):# 定义空网络结构conv_blks = []in_channels = 1# 卷积层部分for (num_convs, out_channels) in conv_arch:# 添加vgg块conv_blks.append(vgg_block(num_convs, in_channels, out_channels))# 下一层输入通道数=当前层输出通道数in_channels = out_channelsreturn nn.Sequential(*conv_blks, nn.Flatten(),# 全连接层部分nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(), nn.Dropout(0.5),nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),nn.Linear(4096, 10))net = vgg(conv_arch)
构建一个高度和宽度为224的单通道数据样本,以观察每个层输出的形状
# 构建一个高度和宽度为224的单通道数据样本,以观察每个层输出的形状
X = torch.randn(size=(1, 1, 224, 224))
for blk in net:X = blk(X)print(blk.__class__.__name__, 'output shape:\t', X.shape)
每一层的输出形状

3. 训练模型
构建了一个通道数较少的网络,足够用于训练Fashion-MNIST数据集
# 构建了一个通道数较少的网络,足够用于训练Fashion-MNIST数据集
ratio = 4
# //为整除
small_conv_arch = [(pair[0], pair[1] // 4) for pair in conv_arch]
net = vgg(small_conv_arch)
定义精度评估函数
"""定义精度评估函数:1、将数据集复制到显存中2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save# 判断net是否属于torch.nn.Module类if isinstance(net, nn.Module):net.eval()# 如果不在参数选定的设备,将其传输到设备中if not device:device = next(iter(net.parameters())).device# Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:# 将X, y复制到设备中if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)# 计算正确预测的数量,总预测的数量,并存储到metric中metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]
定义GPU 训练函数
"""定义GPU训练函数:1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;2、使用Xavier随机初始化模型参数;3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""# 定义初始化参数,对线性层和卷积层生效def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)# 在设备device上进行训练print('training on', device)net.to(device)# 优化器:随机梯度下降optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 损失函数:交叉熵损失函数loss = nn.CrossEntropyLoss()# Animator为绘图函数animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])# 调用Timer函数统计时间timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量metric = d2l.Accumulator(3)net.train()# enumerate() 函数用于将一个可遍历的数据对象for i, (X, y) in enumerate(train_iter):timer.start() # 进行计时optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将特征和标签转移到devicey_hat = net(X)l = loss(y_hat, y) # 交叉熵损失l.backward() # 进行梯度传递返回optimizer.step()with torch.no_grad():# 统计损失、预测正确数和样本数metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop() # 计时结束train_l = metric[0] / metric[2] # 计算损失train_acc = metric[1] / metric[2] # 计算精度# 进行绘图if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))# 测试精度test_acc = evaluate_accuracy_gpu(net, test_iter) animator.add(epoch + 1, (None, None, test_acc))# 输出损失值、训练精度、测试精度print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'f'test acc {test_acc:.3f}')# 设备的计算能力print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'f'on {str(device)}')

进行训练
# 学习率略高
lr, num_epochs, batch_size = 0.05, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

块的使用导致网络定义的非常简洁。使用块可以有效地设计复杂的网络。
相关文章:
动手学深度学习—使用块的网络VGG(代码详解)
目录 1. VGG块2. VGG网络3. 训练模型 1. VGG块 经典卷积神经网络的基本组成部分是下面的这个序列: 1.带填充以保持分辨率的卷积层; 2.非线性激活函数,如ReLU; 3.汇聚层,如最大汇聚层。 定义网络块,便于我…...
性能优化:JIT即时编译与AOT提前编译
优质博文:IT-BLOG-CN 一、简介 JIT与AOT的区别: 两种不同的编译方式,主要区别在于是否处于运行时进行编译。 JIT:Just-in-time动态(即时)编译,边运行边编译:在程序运行时,根据算法计算出热点代码…...
抖音同城榜:探索城市新潮流
随着科技的飞速发展,短视频已经成为了人们日常生活中不可或缺的一部分。作为短视频领域的佼佼者,抖音一直致力于为用户带来更丰富、更有趣的短视频内容。抖音同城榜应运而生,成为了最新、最热门的话题聚集地,吸引了大量潮流达人和…...
云表|低代码开发崛起:重新定义企业级应用开发
低代码开发这个概念在近年来越来越受到人们的关注,市场对于低代码的需求也日益增长。据Gartner预测,到2025年,75%的大型企业将使用至少四种低代码/无代码开发工具,用于IT应用开发和公民开发计划。 那么,为什…...
【算法题】2906. 构造乘积矩阵
题目: 给你一个下标从 0 开始、大小为 n * m 的二维整数矩阵 grid ,定义一个下标从 0 开始、大小为 n * m 的的二维矩阵 p。如果满足以下条件,则称 p 为 grid 的 乘积矩阵 : 对于每个元素 p[i][j] ,它的值等于除了 g…...
机器学习基础之《回归与聚类算法(4)—逻辑回归与二分类(分类算法)》
一、什么是逻辑回归 1、逻辑回归(Logistic Regression)是机器学习中的一种分类模型,逻辑回归是一种分类算法,虽然名字中带有回归,但是它与回归之间有一定的联系。由于算法的简单和高效,在实际中应用非常广…...
UWB安全数据通讯STS-加密、身份认证
DW3000系列才能支持UWB安全数据通讯,DW1000不支持 IEEE 802.15.4a没有数据通讯安全保护机制,IEEE 802.15.4z中指定的扩展得到增强(在PHY/RF级别):增添了一个重要特性“扰频时间戳序列(STS)”&a…...
vue3中去除eslint严格模式
vue3中去除eslint严格模式 1、全局搜索:extends 2、一般在package.json或者vue.config.js中,直接删除掉vue/standard,重启项目。(在package.json文件中,编译不允许有注释,所以直接删掉)...
Win10如何彻底关闭wsappx进程?
Win10如何彻底关闭wsappx进程?在Win10电脑中,用户看到了wsappx进程占用了大量的系统资源,所以想结束wsappx进程,提升电脑的运行速度。但是,用户们不知道彻底关闭掉wsappx进程的方法,那么接下来小编就给大家…...
docker 安装 sftpgo
sftpgo 简介 sftpgo 是一个功能齐全且高度可配置的 SFTP 服务器,具有可选的 HTTP/S、FTP/S 和 WebDAV 支持。支持多种存储后端:本地文件系统、加密本地文件系统、S3(兼容)对象存储、Google 云存储、Azure Blob 存储、SFTP。 官…...
threejs (一) 创建一个场景
引入 npm install three import * as THREE from three;const scene new THREE.Scene();或者使用bootCDN复制对应的版本连接 <script src"https://cdn.bootcdn.net/ajax/libs/three.js/0.156.1/three.js"></script>基础知识 场景、相机、渲染器 通过…...
二分查找,求方程多解
1.暴力遍历: 解为两位小数,故0.001的范围肯定可以包含(零点存在) 2.均分为区间长度为1的小区间(由于两解,距离不小于1),一个区间最多一个解 1.防止两边端点都为解 2&…...
代码随想录算法训练营第二十九天 | 回溯算法总结
代码随想录算法训练营第二十九天 | 回溯算法总结 1. 组合问题 1.1 组合问题 在77. 组合中,我们开始用回溯法解决第一道题目:组合问题。 回溯算法跟k层for循环同样是暴力解法,为什么用回溯呢?回溯法的魅力,用递…...
运算方法和运算电路
一、逻辑门电路 1、逻辑门电路基础总结 2、异或运算妙用 3、逻辑常用公式 二、加法器(重点) 1、标志位的生成原理 2、加法器总结 三、多路门选择器,三态门...
计算机网络篇之TCP滑动窗口
文章目录 前言概述 前言 在网络数据传输时,若传输的原始数据包比较大,会将数据包分解成多个数据包进行发送。需要对数据包确认后,才能发送下一个数据包。在等待确认包的这个过程浪费了大量的时间,不过还好TCP引入了滑动窗口的概念…...
本地安装telepresence,访问K8S集群 Mac(m1) 非管理員
kubeconfig 一.安装telepresence 1.安装 Telepresence Quickstart | Telepresence (1)brew install datawire/blackbird/telepresence 2.配置 目录kubectl 将使用默认的 kubeconfig 文件:$HOME/.kube/config 创建文件夹&…...
今日思考(2) — 训练机器学习模型用GPU还是NUP更有优势(基于文心一言的回答)
前言 深度学习用GPU,强化学习用NPU。 1.训练深度学习模型,强化学习模型用NPU还是GPU更有优势 在训练深度学习模型时,GPU相比NPU有优势。GPU拥有更高的访存速度和更高的浮点运算能力,因此更适合深度学习中的大量训练数据、大量矩阵…...
8.3 C++ 定义并使用类
C/C语言是一种通用的编程语言,具有高效、灵活和可移植等特点。C语言主要用于系统编程,如操作系统、编译器、数据库等;C语言是C语言的扩展,增加了面向对象编程的特性,适用于大型软件系统、图形用户界面、嵌入式系统等。…...
Git学习笔记——超详细
Git笔记 安装git: apt install git 创建版本库: git init 添加文件到版本库: git add 文件 提交文件到仓库: git commit -m “注释” 查看仓库当前的状态信息: git status 查看修改内容和之前版本的区别&am…...
Locust负载测试工具实操
本中介绍如何使用Locust为开发的服务/网站执行负载测试。 Locust 是一个开源负载测试工具,可以通过 Python 代码构造来定义用户行为,避免混乱的 UI 和臃肿的 XML 配置。 步骤 设置Locust。 在简单的 HTTP 服务上模拟基本负载测试。 准备条件 Python…...
深入浅出Asp.Net Core MVC应用开发系列-AspNetCore中的日志记录
ASP.NET Core 是一个跨平台的开源框架,用于在 Windows、macOS 或 Linux 上生成基于云的新式 Web 应用。 ASP.NET Core 中的日志记录 .NET 通过 ILogger API 支持高性能结构化日志记录,以帮助监视应用程序行为和诊断问题。 可以通过配置不同的记录提供程…...
label-studio的使用教程(导入本地路径)
文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...
376. Wiggle Subsequence
376. Wiggle Subsequence 代码 class Solution { public:int wiggleMaxLength(vector<int>& nums) {int n nums.size();int res 1;int prediff 0;int curdiff 0;for(int i 0;i < n-1;i){curdiff nums[i1] - nums[i];if( (prediff > 0 && curdif…...
postgresql|数据库|只读用户的创建和删除(备忘)
CREATE USER read_only WITH PASSWORD 密码 -- 连接到xxx数据库 \c xxx -- 授予对xxx数据库的只读权限 GRANT CONNECT ON DATABASE xxx TO read_only; GRANT USAGE ON SCHEMA public TO read_only; GRANT SELECT ON ALL TABLES IN SCHEMA public TO read_only; GRANT EXECUTE O…...
【SQL学习笔记1】增删改查+多表连接全解析(内附SQL免费在线练习工具)
可以使用Sqliteviz这个网站免费编写sql语句,它能够让用户直接在浏览器内练习SQL的语法,不需要安装任何软件。 链接如下: sqliteviz 注意: 在转写SQL语法时,关键字之间有一个特定的顺序,这个顺序会影响到…...
苍穹外卖--缓存菜品
1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得,如果用户端访问量比较大,数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据,减少数据库查询操作。 缓存逻辑分析: ①每个分类下的菜品保持一份缓存数据…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个生活电费的缴纳和查询小程序
一、项目初始化与配置 1. 创建项目 ohpm init harmony/utility-payment-app 2. 配置权限 // module.json5 {"requestPermissions": [{"name": "ohos.permission.INTERNET"},{"name": "ohos.permission.GET_NETWORK_INFO"…...
第 86 场周赛:矩阵中的幻方、钥匙和房间、将数组拆分成斐波那契序列、猜猜这个单词
Q1、[中等] 矩阵中的幻方 1、题目描述 3 x 3 的幻方是一个填充有 从 1 到 9 的不同数字的 3 x 3 矩阵,其中每行,每列以及两条对角线上的各数之和都相等。 给定一个由整数组成的row x col 的 grid,其中有多少个 3 3 的 “幻方” 子矩阵&am…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...
【开发技术】.Net使用FFmpeg视频特定帧上绘制内容
目录 一、目的 二、解决方案 2.1 什么是FFmpeg 2.2 FFmpeg主要功能 2.3 使用Xabe.FFmpeg调用FFmpeg功能 2.4 使用 FFmpeg 的 drawbox 滤镜来绘制 ROI 三、总结 一、目的 当前市场上有很多目标检测智能识别的相关算法,当前调用一个医疗行业的AI识别算法后返回…...
