当前位置: 首页 > news >正文

李沐动手学习深度学习——4.2练习

1. 在所有其他参数保持不变的情况下,更改超参数num_hiddens的值,并查看此超参数的变化对结果有何影响。确定此超参数的最佳值。

通过改变隐藏层的数量,导致就是函数拟合复杂度下降,隐藏层过多可能导致过拟合,而过少导致欠拟合。
我们将层数改为128可得:
在这里插入图片描述

2. 尝试添加更多的隐藏层,并查看它对结果有何影响。

过拟合,导致测试机精确度下降。

3. 改变学习速率会如何影响结果?保持模型架构和其他超参数(包括轮数)不变,学习率设置为多少会带来最好的结果?

过高的学习率导致,梯度跨度过大,使得降低不到对应的驻点。
过低的学习率导致训练缓慢,需要增加epoch。
在训练轮数不变的情况下,我们可以通过for 设置不同的学习率找出最合适的学习率。一般来说设置为0.01或者0.1足以

4. 通过对所有超参数(学习率、轮数、隐藏层数、每层的隐藏单元数)进行联合优化,可以得到的最佳结果是什么?

跑了一次学习率lr=0.01的情况:
在这里插入图片描述

需要大量的训练,但是目前我训练结果是学习率lr=0.1、轮数是num_epochs=10,隐藏层数为1,隐藏层数单元num_hiddens=128。

5. 描述为什么涉及多个超参数更具挑战性。

因为组合的情况更多,当层数越多时,训练时间也更多,这玩意就是炼丹了,看你自己的GPU还有时间、运气。

6. 如果想要构建多个超参数的搜索方法,请想出一个聪明的策略。

套用for 循环暴力破解,时间上肯定慢的要死,我们可以先固定其他变量,挑选一个变量寻找最优解,以此类推对所有的超参数这样使用,但是这种做法肯定不是最优的,只是能够较好的找出比较好的超参数。

由于学校穷逼所以没有闲置GPU服务器,所有的模型只能在colab上进行运行,其中遇到了d2l的版本对应问题,所以对于d2l.train_ch3跑不起来,只能使用自写进行替代如下:

import torch.nn
from d2l import torch as d2l
from IPython import displayclass Accumulator:"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * n       # 创建一个长度为 n 的列表,初始化所有元素为0.0。def add(self, *args):           # 累加self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):                # 重置累加器的状态,将所有元素重置为0.0self.data = [0.0] * len(self.data)def __getitem__(self, idx):     # 获取所有数据return self.data[idx]def accuracy(y_hat, y):"""计算正确的数量:param y_hat::param y::return:"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)            # 在每行中找到最大值的索引,以确定每个样本的预测类别cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def evaluate_accuracy(net, data_iter):"""计算指定数据集的精度:param net::param data_iter::return:"""if isinstance(net, torch.nn.Module):net.eval()                  # 通常会关闭一些在训练时启用的行为metric = Accumulator(2)with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]class Animator:"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量的绘制多条线if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):"""向图表中添加多个数据点:param x::param y::return:"""if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)display.clear_output(wait=True)def train_epoch_ch3(net, train_iter, loss, updater):"""训练模型一轮:param net:是要训练的神经网络模型:param train_iter:是训练数据的数据迭代器,用于遍历训练数据集:param loss:是用于计算损失的损失函数:param updater:是用于更新模型参数的优化器:return:"""if isinstance(net, torch.nn.Module):  # 用于检查一个对象是否属于指定的类(或类的子类)或数据类型。net.train()# 训练损失总和, 训练准确总和, 样本数metric = Accumulator(3)for X, y in train_iter:  # 计算梯度并更新参数y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):  # 用于检查一个对象是否属于指定的类(或类的子类)或数据类型。# 使用pytorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()  # 方法用于计算损失的平均值updater.step()else:# 使用定制(自定义)的优化器和损失函数l.sum().backward()updater(X.shape())metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):"""训练模型():param net::param train_iter::param test_iter::param loss::param num_epochs::param updater::return:"""animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):trans_metrics = train_epoch_ch3(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)animator.add(epoch + 1, trans_metrics + (test_acc,))train_loss, train_acc = trans_metricsprint(trans_metrics)def predict_ch3(net, test_iter, n=6):"""进行预测:param net::param test_iter::param n::return:"""global X, yfor X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true + "\n" + pred for true, pred in zip(trues, preds)]d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])d2l.plt.show()

相关文章:

李沐动手学习深度学习——4.2练习

1. 在所有其他参数保持不变的情况下,更改超参数num_hiddens的值,并查看此超参数的变化对结果有何影响。确定此超参数的最佳值。 通过改变隐藏层的数量,导致就是函数拟合复杂度下降,隐藏层过多可能导致过拟合,而过少导…...

CYQ.Data 支持 DaMeng 达梦数据库

DaMeng 达梦数据库介绍: 达梦数据库(DMDB)是中国自主研发的关系型数据库管理系统,由达梦科技股份有限公司开发。 达梦数据库提供了企业级的数据库解决方案,广泛应用于金融、电信、政府、制造等行业领域。 达梦数据库具有以下特点和优势: 高性能:具备高性能的并发处理…...

计网面试题整理上

1. 计算机网络的各层协议及作用? 计算机网络体系可以大致分为一下三种,OSI七层模型、TCP/IP四层模型和五层模型。 OSI七层模型:大而全,但是比较复杂、而且是先有了理论模型,没有实际应用。TCP/IP四层模型&#xff1a…...

code: 500 ] This subject is anonymous - it does not have any identifying

项目场景: 相关背景: 使用idea 开发java 项目,前端页面请求 页面中相关的接口时,idea 控制台有报错信息出现,前端请求失败。 问题描述 问题: 使用idea 开发java 项目,前端页面请求 页面中相…...

FC-AE-1553 协议

FC-AE-1553 协议 MIL-STD-1553B总线协议总线结构字格式消息传输方式 FC协议FC协议栈拓扑结构服务类型帧/序列/交换FC帧格式 FC-AE-1553网络构成帧类型命令帧状态帧数据帧 Information UnitsNC1NC2NC3-4NC5-7NT1-7 传输模式1. NC-NT2. NT-NC3. NT-NT4. 无数据字的模式命令5. 带数…...

代码随想录算法训练营day38|理论基础、509. 斐波那契数、70. 爬楼梯、746. 使用最小花费爬楼梯

理论基础 代码随想录 视频:从此再也不怕动态规划了,动态规划解题方法论大曝光 !| 理论基础 |力扣刷题总结| 动态规划入门_哔哩哔哩_bilibili 动态规划:如果某一问题有很多重叠子问题,使用动态规划是最有效的。所以动态…...

夫妻一方名下股权到底归谁?

生效判决摘要:1.夫妻一方在婚姻关系存续期间投资的收益,为夫妻的共同财产,归夫妻共同所有,但是并不能据此否定股权本身可能成为夫妻共同财产。婚姻关系存续期间登记在配偶一方名下的股权能否成为夫妻共同财产,可由司法…...

git根据文件改动将文件自动添加到缓冲区

你需要修改以下脚本中的 use_cca: false 部分 #!/bin/bash# 获取所有已修改但未暂存的文件 files$(git diff --name-only)for file in $files; do# 检查文件中是否存在"use_cca: false"if grep -q "use_cca: false" "$file"; thenecho "Ad…...

SystemVerilog Constants、Processes

SystemVerilog提供了三种类型的精化时间常数: •参数:与最初的Verilog标准相同,可以以相同的方式使用。 •localparameter:与参数类似,但不能被上层覆盖模块。 •specparam:用于指定延迟和定时值&#x…...

交易平台开发:构建安全/高效/用户友好的在线交易生态圈

在数字化浪潮的推动下,农产品现货大宗商品撮合交易平台已成为连接全球买家与卖家的核心枢纽。随着电子商务的飞速发展,一个安全、高效、用户友好的交易平台对于促进交易、提升用户体验和增加用户黏性至关重要。本文将深入探讨交易平台开发的关键要素&…...

Linux系统之部署复古游戏平台

Linux系统之部署复古游戏平台 前言一、项目介绍1.1 项目简介1.2 项目特点1.3 游戏平台介绍二、本次实践介绍二、本地环境介绍2.1 本地环境规划2.2 本次实践介绍三、本地环境检查3.1 安装Docker环境3.2 检查Docker服务状态3.3 检查Docker版本3.4 检查docker compose 版本四、构建…...

开源计算机视觉库opencv-python详解

开源计算机视觉库opencv-python详解 OpenCV-Python的核心功能:安装OpenCV-Python:使用OpenCV-Python的基本步骤:OpenCV-Python的高级应用:注意事项:OpenCV-Python的高级应用示例:1. 人脸识别2. 目标跟踪3. …...

Vue开发实例(十)Tabs标签页打开、关闭与路由之间的关系

创建标签页 一、创建标签页二、点击菜单展示新标签页1、将标签数据作为全局使用2、菜单点击增加标签页3、处理重复标签4、关闭标签页 三、点击标签页操作问题1:点击标签页选中菜单进行高亮展示问题2:点击标签页路由也要跳转 四、解决bug 先展示最终效果 …...

基于51单片机的智能火灾报警系统

基于51单片机的智能火灾报警系统 摘要: 本文提出了一种基于51单片机的智能火灾报警系统。该系统采用烟雾传感器和温度传感器来检测火灾的发生,并通过单片机进行数据处理和报警控制。此外,该系统还具有无线通信功能,可以实时将火灾…...

【数据结构】堆的TopK问题

大家好,我是苏貝,本篇博客带大家了解堆的TopK问题,如果你觉得我写的还不错的话,可以给我一个赞👍吗,感谢❤️ 目录 一. 前言二. TopK三. 代码 一. 前言 TOP-K问题:即求数据结合中前K个最大的元…...

Vue后台管理系统笔记-01

npm(Node Package Manager)和 yarn 是两个常用的包管理工具,用于在 Node.js 项目中安装、管理和更新依赖项。它们有以下几个区别: 性能和速度:在包的安装和下载方面,yarn 通常比 npm 更快速。yarn 使用了并…...

飞天使-学以致用-devops知识点3-安装jenkins

文章目录 构建带maven环境的jenkins 镜像安装jenkinsjenkins yaml 文件安装插件jenkins 配置k8s创建用户凭证 构建带maven环境的jenkins 镜像 # 构建带 maven 环境的 jenkins 镜像 docker build -t 192.168.113.122:8858/library/jenkins-maven:jdk-11 .# 登录 harbor docker …...

08、MongoDB -- MongoDB 的 集合关联($lookup 和 DBRef 实现集合关联)

目录 MongoDB 的 集合关联演示前提:登录单机模式的 mongodb 服务器命令登录【test】数据库的 mongodb 客户端命令登录【admin】数据库的 mongodb 客户端命令 SQL 术语 与 Mongodb 的对应关系使用 $lookup 实现集合关联语法格式添加测试数据1、查询出订单数量大于6&a…...

前方高能,又一波Smartbi签约喜报来袭

近期,交通银行、厦门国际银行、中原农业保险、江苏中天科技等多家知名企业签约Smartbi,携手Smartbi实现数据驱动业务新增长。 Smartbi数10年专注于商业智能BI与大数据分析软件与服务,为各行各业提供提供一站式商业智能平台(PaaS&a…...

蓝桥杯倒计时 41天 - 二分答案-最大通过数-妮妮的月饼工厂

最大通过数 思路&#xff1a;假设左边能通过 x 关&#xff0c;右边能通过 y 关&#xff0c;x∈[0,n]&#xff0c;通过二分&#xff0c;在前缀和中枚举右边通过的关卡数&#xff0c;保存 xy 的最大值。 #include<bits/stdc.h> using namespace std; typedef long long ll…...

Vue3 + Element Plus + TypeScript中el-transfer穿梭框组件使用详解及示例

使用详解 Element Plus 的 el-transfer 组件是一个强大的穿梭框组件&#xff0c;常用于在两个集合之间进行数据转移&#xff0c;如权限分配、数据选择等场景。下面我将详细介绍其用法并提供一个完整示例。 核心特性与用法 基本属性 v-model&#xff1a;绑定右侧列表的值&…...

DAY 47

三、通道注意力 3.1 通道注意力的定义 # 新增&#xff1a;通道注意力模块&#xff08;SE模块&#xff09; class ChannelAttention(nn.Module):"""通道注意力模块(Squeeze-and-Excitation)"""def __init__(self, in_channels, reduction_rat…...

sqlserver 根据指定字符 解析拼接字符串

DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角&#xff0c;以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向&#xff0c;距离坐标原点x个像素;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐标原点y个像素。 坐标体系-像素 …...

聊一聊接口测试的意义有哪些?

目录 一、隔离性 & 早期测试 二、保障系统集成质量 三、验证业务逻辑的核心层 四、提升测试效率与覆盖度 五、系统稳定性的守护者 六、驱动团队协作与契约管理 七、性能与扩展性的前置评估 八、持续交付的核心支撑 接口测试的意义可以从四个维度展开&#xff0c;首…...

OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 在 GPU 上对图像执行 均值漂移滤波&#xff08;Mean Shift Filtering&#xff09;&#xff0c;用于图像分割或平滑处理。 该函数将输入图像中的…...

基于 TAPD 进行项目管理

起因 自己写了个小工具&#xff0c;仓库用的Github。之前在用markdown进行需求管理&#xff0c;现在随着功能的增加&#xff0c;感觉有点难以管理了&#xff0c;所以用TAPD这个工具进行需求、Bug管理。 操作流程 注册 TAPD&#xff0c;需要提供一个企业名新建一个项目&#…...

打手机检测算法AI智能分析网关V4守护公共/工业/医疗等多场景安全应用

一、方案背景​ 在现代生产与生活场景中&#xff0c;如工厂高危作业区、医院手术室、公共场景等&#xff0c;人员违规打手机的行为潜藏着巨大风险。传统依靠人工巡查的监管方式&#xff0c;存在效率低、覆盖面不足、判断主观性强等问题&#xff0c;难以满足对人员打手机行为精…...

第7篇:中间件全链路监控与 SQL 性能分析实践

7.1 章节导读 在构建数据库中间件的过程中&#xff0c;可观测性 和 性能分析 是保障系统稳定性与可维护性的核心能力。 特别是在复杂分布式场景中&#xff0c;必须做到&#xff1a; &#x1f50d; 追踪每一条 SQL 的生命周期&#xff08;从入口到数据库执行&#xff09;&#…...

如何应对敏捷转型中的团队阻力

应对敏捷转型中的团队阻力需要明确沟通敏捷转型目的、提升团队参与感、提供充分的培训与支持、逐步推进敏捷实践、建立清晰的奖励和反馈机制。其中&#xff0c;明确沟通敏捷转型目的尤为关键&#xff0c;团队成员只有清晰理解转型背后的原因和利益&#xff0c;才能降低对变化的…...