当前位置: 首页 > news >正文

Life Long Learning(李宏毅)机器学习 2023 Spring HW14 (Boss Baseline)

1. 终身学习简介

神经网络的典型应用场景是,我们有一个固定的数据集,在其上训练并获得模型参数,然后将模型应用于特定任务而无需进一步更改模型参数。

然而,在许多实际工程应用中,常见的情况是系统可以不断地获取新数据,例如 Web 应用程序中的新用户数据或自动驾驶中的新驾驶数据。 这些新数据需要被纳入训练集以增强模型的性能。 这被称为终身学习。

Life Long Learning (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/life_v2.pdf)

人工智能终身学习面临哪些挑战?似乎通过不断更新其数据和相应的网络参数就可以实现终身学习。然而,事实并非如此简单。一个关键障碍是灾难性遗忘。

如示例所示,在任务 1 上训练的神经网络达到了 90% 的准确率。在随后对任务 2 进行训练后,其在任务 2 上的性能提高到 97%。然而,该网络在原始任务 1 上的性能急剧下降至 80%。这表明学习新信息会导致网络“遗忘”先前学习的知识,阻碍其随着时间的推移积累知识的能力。

catastrophic forgetting (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/life_v2.pdf)

此外,这种遗忘并不是由于神经网络缺乏同时处理两个任务的能力。如附加图表所示,如果我们同时对两个任务进行训练(交替训练来自任务 1 和任务 2 的数据),它在两个任务上都达到了很高的准确率(任务 1 上为 89%,任务 2 上为 98%)。 这表明网络确实具有学习两个任务的能力;问题在于学习过程的顺序性,其中学习新任务会覆盖先前学习的知识。

network capacity (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/life_v2.pdf)

灾难性遗忘背后的核心问题在于不同任务之间的优化方向存在冲突。当神经网络在一个任务上进行训练时,它会调整其参数(用 θ 表示)以最小化该特定任务的损失函数。这个过程会找到一组对于该任务表现良好的最优参数( θ∗ )。然而,当网络随后在新的任务上进行训练时,它会再次调整其参数以最小化新任务的损失。这个新的优化过程通常会将参数从先前为原始任务找到的最优解中拉开,有效地“覆盖”了为第一个任务学习到的知识。当任务之间差异较大时,这种现象尤为明显。

why catastrophic forgetting (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/life_v2.pdf)

鉴于这一挑战,我们如何缓解灾难性遗忘?这篇综合综述论文概述了当前最先进的方法(截至 2024 年),将其分为五个关键方法。

A Comprehensive Survey of Continual Learning: Theory, Method and Application​icon-default.png?t=O83Ahttp://arxiv.org/abs/2302.00487

Continual Learning Method (source: https://arxiv.org/abs/2302.00487)

基于正则化的方法中一个值得注意的技术是选择性突触可塑性。“突触权重”表示神经元(节点)之间连接的强度,“可塑性”被用作对生物学概念神经可塑性的类比,后者描述了大脑改变神经元之间连接强度(突触权重)的能力。选择性可塑性是指神经网络期望具有的特性,即选择性地调节其整个结构中单个突触的可塑性。

Selective Synaptic Plasticity (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/life_v2.pdf)

如图所示,参数 b 控制着可塑性的水平。在一个极端情况下,当 b 等于 0 时,正则化项消失,有效地消除了对参数更新的任何约束。这种约束的缺失导致了灾难性遗忘,其中学习新任务完全覆盖了先前获得的知识。在另一个极端情况下,当 b 接近无穷大时,正则化变得无限强,阻止了对参数的任何重大更改。在这种情况下,模型变得“固执”,无法学习任何新东西,只能停留在其初始知识上。

how b affects the performance (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/life_v2.pdf)

所以,关键问题是如何计算这种“可塑性”?李宏毅教授在课程中推荐了以下论文:

Elastic Weight Consolidation (EWC) (弹性权重巩固 ):

Overcoming catastrophic forgetting in neural networksicon-default.png?t=O83Ahttp://​arxiv.org/abs/1612.00796

EWC 使用 Fisher 信息矩阵 (FFF) 来衡量每个权重的重要性,并在学习新任务时惩罚对关键权重的更改:

L_{\text{total}} = L_{\text{task}} + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta^*_{i})^2 \\

  • L_{\text{task}}​: 任务特定的损失

  • F_i : 权重 \theta_i 的 Fisher 信息

  • \theta^*_{i} ​: 来自先前任务的最优权重

  • \lambda: 正则化强度

F 的定义如下所示:

F = [ \nabla \log(p(y_n | x_n, \theta^{*})) \nabla \log(p(y_n | x_n, \theta^{*}))^T ] \\

我们仅取矩阵的对角线值来近似每个参数的 F_i .

相应的代码实现为:

# EWC
class ewc(object):"""@article{kirkpatrick2017overcoming,title={Overcoming catastrophic forgetting in neural networks},author={Kirkpatrick, James and Pascanu, Razvan and Rabinowitz, Neil and Veness, Joel and Desjardins, Guillaume and Rusu, Andrei A and Milan, Kieran and Quan, John and Ramalho, Tiago and Grabska-Barwinska, Agnieszka and others},journal={Proceedings of the national academy of sciences},year={2017},url={https://arxiv.org/abs/1612.00796}}"""def __init__(self, model, dataloader, device, prev_guards=[None]):self.model = modelself.dataloader = dataloaderself.device = device# extract all parameters in modelsself.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}# initialize parametersself.p_old = {}# save previous guardsself.previous_guards_list = prev_guards# generate Fisher (F) matrix for EWCself._precision_matrices = self._calculate_importance()# keep the old parameter in self.p_oldfor n, p in self.params.items():self.p_old[n] = p.clone().detach()def _calculate_importance(self):precision_matrices = {}# initialize Fisher (F) matrix(all fill zero)and add previous guardsfor n, p in self.params.items():precision_matrices[n] = p.clone().detach().fill_(0)for i in range(len(self.previous_guards_list)):if self.previous_guards_list[i]:precision_matrices[n] += self.previous_guards_list[i][n]self.model.eval()if self.dataloader is not None:number_data = len(self.dataloader)for data in self.dataloader:self.model.zero_grad()# get image datainput = data[0].to(self.device)# image data forward modeloutput = self.model(input)# Simply use groud truth label of dataset.label = data[1].to(self.device)# generate Fisher(F) matrix for EWCloss = F.nll_loss(F.log_softmax(output, dim=1), label)loss.backward()for n, p in self.model.named_parameters():# get the gradient of each parameter and square it, then average it in all validation set.precision_matrices[n].data += p.grad.data ** 2 / number_dataprecision_matrices = {n: p for n, p in precision_matrices.items()}return precision_matricesdef penalty(self, model: nn.Module):loss = 0for n, p in model.named_parameters():# generate the final regularization term by the ewc weight (self._precision_matrices[n]) and the square of weight difference ((p - self.p_old[n]) ** 2)._loss = self._precision_matrices[n] * (p - self.p_old[n]) ** 2loss += _loss.sum()return lossdef update(self, model):# do nothingreturn

Synaptic Intelligence (SI) (突触智能):

Continual Learning Through Synaptic Intelligenceicon-default.png?t=O83Ahttp://arxiv.org/abs/1703.04200

SI 通过计算参数 i 的重要权重 \Omega_i 来衡量其重要性,计算方法如下:

\Omega_i^{\text{new}} \leftarrow \Omega_i^{\text{old}} + \frac{W_i}{(\theta_i - \theta_i^*)^2 + \epsilon} \\

其中:

  • \Omega_i^{\text{new}} :参数 i 的更新重要权重

  • \Omega_i^{\text{old}}​:来自先前任务的重要权重

  • W_i​:累积梯度信息

  • \theta_i - \theta_i^* ​:任务期间的参数变化

  • \epsilon : 防止出现零除数的小正数常量

然后,正则化惩罚项 LSI 被计算为:

相关文章:

Life Long Learning(李宏毅)机器学习 2023 Spring HW14 (Boss Baseline)

1. 终身学习简介 神经网络的典型应用场景是,我们有一个固定的数据集,在其上训练并获得模型参数,然后将模型应用于特定任务而无需进一步更改模型参数。 然而,在许多实际工程应用中,常见的情况是系统可以不断地获取新数据,例如 Web 应用程序中的新用户数据或自动驾驶中的…...

libc.so.6不兼容

1、查看电脑所有libc.so.6 daviddavid-Shangqi-X4270:~/MySoft/ubuntusoft$ locate libc.so.6 /home/david/MySoft/ubuntusoft/EXEApp/libc.so.6 /home/david/MySoft/ubuntusoft/EXEApp_TEST/libc.so.6 /home/david/MySoft/ubuntusoft/RTMG_APP/libc.so.6 /home/david/MySoft/…...

树的模拟实现

一.链式前向星 所谓链式前向星,就是用链表的方式实现树。其中的链表是用数组模拟实现的链表。 首先我们需要创建一个足够大的数组h,作为所有结点的哨兵位。创建两个足够大的数组e和ne,一个作为数据域,一个作为指针域。创建一个变…...

AsyncOperation.allowSceneActivation导致异步加载卡死

先看这段代码,有个诡异的问题,不确定是不是bug public class Test : MonoBehaviour {void Start(){StartCoroutine(LoadScene(Ego.LoadingLevel));}IEnumerator LoadScene(string sceneName){LoadingUI.UpdateProgress(0.9f);yield return new WaitForS…...

如何搭建 Vue.js 开源项目的 CI/CD 流水线

网罗开发 (小红书、快手、视频号同名) 大家好,我是 展菲,目前在上市企业从事人工智能项目研发管理工作,平时热衷于分享各种编程领域的软硬技能知识以及前沿技术,包括iOS、前端、Harmony OS、Java、Python等…...

单通道串口服务器(三格电子)

一、产品介绍 1.1 功能简介 SG-TCP232-110 是一款用来进行串口数据和网口数据转换的设备。解决普通 串口设备在 Internet 上的联网问题。 设备的串口部分提供一个 232 接口和一个 485 接口,两个接口内部连接,同 时只能使用一个口工作。 设 备 的网 口…...

【Excel/WPS】根据平均值,生成两列/多列指定范围的随机数/随机凑出两列数据

原理就是通过随机生成函数和平均值函数。 适用场景:在总体打分后,需要在小项中随机生成小分数 第一列:固定的平均值A2第二列: RANDBETWEEN(A2-10,A210)第三列:根据第二列用平均值函数算除 A2*2-B2这是随机值1的公式&am…...

使用网页版Jupyter Notebook和VScode打开.ipynb文件

目录 正文 1、网页版Jupyter Notebook查看 2、VScode查看 因为总是忘记查看文件的网址,收藏了但分类众多每次都找不到……当个记录吧(/捂脸哭)! 正文 此处以gitub中的某个仓库为例: https://github.com/INM-6/mu…...

记录一下vue2项目优化,虚拟列表vue-virtual-scroll-list处理10万条数据

文章目录 封装BrandPickerVirtual.vue组件页面使用组件属性 select下拉接口一次性返回10万条数据,页面卡死,如何优化??这里使用 分页 虚拟列表(vue-virtual-scroll-list),去模拟一个下拉的内容…...

CDA数据分析师一级经典错题知识点总结(5)

1、数值型缺失值用中位数补充,分类数据用众数补充。 2、偏态系数>1就是高度偏,0.5到1是中度。 3、分布和检验 在 t检验之前进行 F检验的目的是确保 t检验的方差齐性假设成立。如果 F检验结果显示方差不相等,则需要切换到调整后的 t 检验…...

服务器、电脑和移动手机操作系统

一、服务器操作系统 1、Windows Server 开发商是微软公司。友好的用户界面、与微软生态系统的高度集成、提供了广泛的企业级功能(如Active Directory、DNS、DHCP服务等)。适合需要大量运行Microsoft应用和服务的企业环境,如SQL Server等。经…...

深入解析 Flink 与 Spark 的性能差异

💖 欢迎来到我的博客! 非常高兴能在这里与您相遇。在这里,您不仅能获得有趣的技术分享,还能感受到轻松愉快的氛围。无论您是编程新手,还是资深开发者,都能在这里找到属于您的知识宝藏,学习和成长…...

如何在 Linux、MacOS 以及 Windows 中打开控制面板

控制面板不仅仅是一系列图标和菜单的集合;它是通往优化个人计算体验的大门。通过它,用户可以轻松调整从外观到性能的各种参数,确保他们的电脑能够完美地适应自己的需求。无论是想要提升系统安全性、管理硬件设备,还是简单地改变桌…...

微信小程序中 隐藏scroll-view 滚动条 网页中隐藏滚动条

在微信小程序中隐藏scroll-view的滚动条可以通过以下几种方法实现: 方法一:使用CSS隐藏滚动条 在小程序的样式文件中(如app.wxss或页面的.wxss文件),添加以下CSS代码来隐藏滚动条: scroll-view ::-webkit…...

Java 实现 Elasticsearch 查询当前索引全部数据

Java 实现 Elasticsearch 查询当前索引全部数据 需求背景通常情况Java 实现查询 Elasticsearch 全部数据写在最后 需求背景 通常情况下,Elasticsearch 为了提高查询效率,对于不指定分页查询条数的查询语句,默认会返回10条数据。那么这就会有…...

android刷机

android ota和img包下载地址: https://developers.google.com/android/images?hlzh-cn android启动过程 线刷 格式:ota格式 模式:recovery 优点:方便、简单,刷机方法通用,不会破坏手机底层数据&#xff0…...

【25考研】西南交通大学计算机复试重点及经验分享!

一、复试内容 上机考试:考试题型为编程上机考试,使用 C 语言,考试时长包括 15 分钟模拟考试和 120 分钟正式考试,考试内容涵盖顺序结构、选择结构、循环结构、数组、指针、字符串处理、函数、递归、结构体、动态存储、链表等知识点…...

OpenCV相机标定与3D重建(49)将视差图(disparity map)重投影到三维空间中函数reprojectImageTo3D()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 将视差图像重投影到3D空间。 cv::reprojectImageTo3D 是 OpenCV 库中的一个函数,用于将视差图(disparity map&#xff09…...

学习HTTP Range

HTTP Range 请求 一种通过指定文件字节范围加载部分数据的技术,广泛用于断点续传、流媒体播放、分布式文件系统的数据分片加载等场景。 请求格式-在请求头中使用 Range 字段指定所需的字节范围 Range: bytes0-1023// bytes0-1023:表示请求文件的第 0 …...

大语言模型训练的数据集从哪里来?

继续上篇文章的内容说说大语言模型预训练的数据集从哪里来以及为什么互联网上的数据已经被耗尽这个说法并不专业,再谈谈大语言模型预训练数据集的优化思路。 1. GPT2使用的数据集是WebText,该数据集大概40GB,由OpenAI创建,主要内…...

Webpack和Vite的区别

一、构建速度方面 webpack默认是将所有模块都统一打包成一个js文件,每次修改都会重写构建整个项目,自上而下串行执行,所以会随着项目规模的增大,导致其构建打包速度会越来越慢 vite只会对修改过的模块进行重构,构建速…...

【再谈设计模式】模板方法模式 - 算法骨架的构建者

一、引言 在软件工程、软件开发过程中,我们经常会遇到一些算法或者业务逻辑具有固定的流程步骤,但其中个别步骤的实现可能会因具体情况而有所不同的情况。模板方法设计模式(Template Method Design Pattern)就为解决这类问题提供了…...

Bytebase 3.1.1 - 可定制的快捷访问首页

🚀 新功能 可定制的快捷访问首页。 支持查询 Redis 集群中所有节点。 赋予项目角色时,过期时间可以定义精确到秒级的时间点。 🔔 重大变更 移除 Database 消息里的实例角色信息。调用 GetInstance 或 ListInstanceRoles 以获取实例角色信息…...

Java阶段四04

第4章-第4节 一、知识点 CSRF、token、JWT 二、目标 理解什么是CSRF攻击以及如何防范 理解什么是token 理解什么是JWT 理解session验证和JWT验证的区别 学会使用JWT 三、内容分析 重点 理解什么是CSRF攻击以及如何防范 理解什么是token 理解什么是JWT 理解session验…...

B2C API安全警示:爬虫之外,潜藏更大风险挑战

在数字化时代,B2C(Business-to-Consumer)电子商务模式已成为企业连接消费者、推动业务增长的重要桥梁。而B2C API(应用程序编程接口)作为企业与消费者之间数据交互的桥梁,其安全性更是至关重要。然而&#…...

OCR文字识别—基于PP-OCR模型实现ONNX C++推理部署

概述 PaddleOCR 是一款基于 PaddlePaddle 深度学习平台的开源 OCR 工具。PP-OCR是PaddleOCR自研的实用的超轻量OCR系统。它是一个两阶段的OCR系统,其中文本检测算法选用DB,文本识别算法选用CRNN,并在检测和识别模块之间添加文本方向分类器&a…...

如何播放视频文件

文章目录 1. 概念介绍2. 使用方法2.1 实现步骤2.2 具体细节3. 示例代码4. 内容总结我们在上一章回中介绍了"如何获取文件类型"相关的内容,本章回中将介绍如何播放视频.闲话休提,让我们一起Talk Flutter吧。 1. 概念介绍 播放视频是我们常用的功能,不过Flutter官方…...

MySQL -- 约束

1. 数据库约束 数据库约束时关系型数据库的一个重要功能,主要的作用是保证数据的有效性,也可以理解为数据的正确性(数据本身是否正确,关联关系是否正确) 人工检查数据的完整性工作量非常大,在数据库中定义一些约束,那么数据在写入数据库的时候,就会帮我们做一些校验.并且约束一…...

php 使用simplexml_load_string转换xml数据格式失败

本文介绍如何使用php函数解析xml数据为数组。 <?php$a <xml><ToUserName><![CDATA[ww8b77afac71336111]]></ToUserName><FromUserName><![CDATA[sys]]></FromUserName><CreateTime>1736328669</CreateTime><Ms…...

net-http-transport 引发的句柄数(协程)泄漏问题

Reference 关于 Golang 中 http.Response.Body 未读取导致连接复用问题的一点研究https://manishrjain.com/must-close-golang-http-responsehttps://www.reddit.com/r/golang/comments/13fphyz/til_go_response_body_must_be_closed_even_if_you/?rdt35002https://medium.co…...