大模型 | NEFTune之引入随机噪声对大模型训练的收益
大模型 | NEFTune之引入随机噪声对大模型训练的收益
paper中提到,在模型foward过程中,对inputs_embedding增加适度的随机噪声,会带来显著的收益。
Paper: https://arxiv.org/pdf/2310.05914.pdf
Github: https://github.com/neelsjain/NEFTune
文章目录
- 大模型 | NEFTune之引入随机噪声对大模型训练的收益
- 理论
- 一. 实践方法
- 1.1 等待Hugging发布该功能
- 1.2 直接封装model
- 1.3 改写compute_loss
理论
核心是输入经过Embedding层后,再加入一个均匀分布的噪声,噪声的采样范围为 [ − α L d , α L d ] [-\frac{\alpha}{\sqrt{Ld}},\frac{\alpha}{\sqrt{Ld}}] [−Ldα,Ldα]之间,其中 α \alpha α为噪声超参,L为输入长度,d为Embedding层维度(即hidden维度)

在AlpacaEval榜单上,利用GPT4作为评分器,在多个数据上微调Llama2-7B模型,NEFTune方法相较于直接微调方法,均有显著提高。

可以缓解模型在指令微调阶段的过拟合现象,可以更好的利用预训练阶段的知识内容。
一. 实践方法
1.1 等待Hugging发布该功能
进度:等待hugging face正式发布此功能,2023-10-26
[10/17/2023] NEFTune has been intregrated into the Huggingface’s TRL (Transformer Reinforcement Learning) library. See Annoucement.
1.2 直接封装model
进度:直接对模型进行如下封装,原理是对model.embed_tokens.forward()进行改写,经实践,这种方法不管用,会报堆栈溢出的error。
from torch.nn import functional as Fdef NEFTune(model, noise_alpha=5)def noised_embed(orig_embed, noise_alpha):def new_func(x):# during training, we add noise to the embedding# during generation, we don't add noise to the embeddingif model.training:embed_init = orig_embed(x)dims = torch.tensor(embed_init.size(1) * embed_init.size(2))mag_norm = noise_alpha/torch.sqrt(dims)return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)else:return orig_embed(x)return new_func##### NOTE: this is for a LLaMA model ##### ##### For a different model, you need to change the attribute path to the embedding #####model.base_model.model.model.embed_tokens.forward = noised_embed(model.base_model.model.model.embed_tokens, noise_alpha)return model
1.3 改写compute_loss
进度:loss能够正常计算,但optimzer会报错,可能与精度有关,尚未解决
由于损失函数是自己写的,因此尝试在model(**input)前,追加噪声代码。注意,原先传入model的是input_ids,而当下由于我们将inputs_embeds增加了噪声,因此传入model的将直接替换为inputs_embeds,代码如下
class TargetLMLossNeft(Loss):def __init__(self, ignore_index):super().__init__()self.ignore_index = ignore_indexself.loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)def __call__(self, model, inputs, training_args, return_outputs=False):input_ids = inputs['input_ids'] # B x L [3, 964]attention_mask = inputs['attention_mask'] # B x L target_mask = inputs['target_mask'] # B x L### ----------------------------- add noise to embedsneftune_alpha = 5embed_device = model.base_model.model.model.embed_tokens.weight.deviceembeds_init = model.base_model.model.model.embed_tokens.forward(input_ids).to(embed_device) # 先forward一下, 变成B X L X hidden_state# embed_device = model.model.embed_tokens.weight.device# embeds_init = model.model.embed_tokens.forward(input_ids).to(embed_device)input_mask = attention_mask.to(embeds_init) # B x Linput_lengths = torch.sum(input_mask, 1) # B, 计算每个sample的实际长度noise_ = torch.zeros_like(embeds_init).uniform_(-1,1) # B X L X hidden_state, 且值域在[-1,1]正态分布delta = noise_ * input_mask.unsqueeze(2) # 追加一个维度,由B X L 变成 B X L X hidden_statedims = input_lengths * embeds_init.size(-1)mag = neftune_alpha / torch.sqrt(dims)delta = (delta * mag.view(-1, 1, 1)).detach() # B X L X hidden_stateinputs_embeds = delta + embeds_init### ----------------------------- add noise to embeds# 模型前馈预测, 原来传入的是input_ids,而现在需要直接将增加了noise的inputs_embeds传入# outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True)logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0] # 正常应该是torch.float32#logits.requires_grad = True # 奇怪,为什么这里会默认为False, 难道是因为上边的detach()# 将labels中不属于target的部分,设为ignore_index,只计算target部分的losslabels = torch.where(target_mask == 1, input_ids, self.ignore_index)shift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()# Flatten the tokensloss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # float32loss.requires_grad = Truereturn (loss, outputs) if return_outputs else loss
相关文章:
大模型 | NEFTune之引入随机噪声对大模型训练的收益
大模型 | NEFTune之引入随机噪声对大模型训练的收益 paper中提到,在模型foward过程中,对inputs_embedding增加适度的随机噪声,会带来显著的收益。 Paper: https://arxiv.org/pdf/2310.05914.pdf Github: https://github.com/neelsjain/NEFT…...
【开源】基于SpringBoot的高校学院网站的设计和实现
目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 学院院系模块2.2 竞赛报名模块2.3 教育教学模块2.4 招生就业模块2.5 实时信息模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 学院院系表3.2.2 竞赛报名表3.2.3 教育教学表3.2.4 招生就业表3.2.5 实时信息表 四、系…...
什么是云原生?土生土长?
“云原生”(Cloud Native)是一种构建和运行应用程序的方法,这种方法充分利用了云计算的优势。云原生应用程序是为云环境设计的,通常是在容器中运行,并被设计为在微服务架构中运行,这使得它们能够快速扩展和…...
2011-2021年北大数字普惠金融指数数据(包括省市县)第四期
2011-2021年北大省市县数字普惠金融指数数据(第四期) 1、时间:2011-2021年 2、指标:index_aggregate、coverage_breadth、usage_depth、payment、insurance、monetary_fund、investment、credit、credit_investigation、digitiz…...
ch3_6多线程举例
作者丨billom 来源丨投稿 编辑丨GiantPandaCV 云端深度学习的服务的性能加速通常需要算法和工程的协同加速,需要模型推理和计算节点的融合,并保证整个“木桶”没有太明显的短板。 如何在满足时延前提下让算法工程师的服务的吞吐尽可能高,尽…...
javaEE -7(网络原理初识 --- 7000字)
一:网络初识 计算机的独立模式是指多台计算机在网络中相互独立运行,彼此之间不共享资源或信息。在早期,计算机主要采用独立模式,每台计算机都拥有自己的操作系统、应用程序和数据,它们之间没有直接的连接或通信。 在…...
新生儿弱视:原因、科普和注意事项
引言: 新生儿弱视,也被称为婴儿弱视或婴儿屈光不正,是一个在婴儿和幼儿时期非常重要的视觉问题。虽然它是一种潜在的视觉障碍,但早期的诊断和干预可以显著改善儿童的视觉发育。本文将科普新生儿弱视的原因,提供相关信…...
【机器学习可解释性】2.特征重要性排列
机器学习可解释性 1.模型洞察的价值2.特征重要性排列3.偏依赖图 ( partial dependence plots )4.SHAP Value5.SHAP Value 高级使用 正文 前言 你的模型认为哪些特征最重要? 介绍 我们可能会对模型提出的最基本的问题之一是:哪…...
机器学习之朴素贝叶斯
朴素贝叶斯: 也叫贝叶算法推断,建立在主管判断的基础上,不断地进行地修正。需要大量的计算。1、主观性强2、大量计算 贝叶斯定理:有先验概率和后验概率区别:假如出门堵车有两个因素:车太多与交通事故先验概…...
Python中if __name__ == ‘__main__‘,__init__和self 的解析
一、 if __name__ __main__ if __name__ __main__的意思是: 当.py文件被直接运行时,if __name__ __main__之下的代码块将被运行; 当.py文件以模块形式被导入时,if __name__ __main__之下的代码块不被运行。 1.1、一个 xxx.p…...
【Superset】自定义授权认证,接入内部系统二次开发
想要将内部系统认证与superset打通,必须要了解superset的认证体系。 Superset的认证体系 Superset的认证体系可以通过以下几种方式进行配置: 基于LDAP认证:Superset可以集成LDAP以验证用户身份。在这种情况下,Superset将根据LDAP…...
私有云:【1】ESXI的安装
私有云:【1】ESXI的安装 1、使用VMware Workstation创建虚拟机2、启动配置虚拟机3、登录ESXI管理台 1、使用VMware Workstation创建虚拟机 新建虚拟机 选择典型安装 稍后安装操作系统 选择VMware ESXI 选择虚拟机安装路径 硬盘设置300G或者更多 自定义硬件 内存和处…...
Mac怎么删除文件和软件?苹果电脑删除第三方软件方法
Mac删除程序这个话题为什么一直重复说或者太多人讨论呢?因为如果操作不当,可能会导致某些不好的影响。因为Mac电脑如果有太多无用的应用程序,很有可能会拖垮Mac系统的运行速度。或者如果因为删除不干净,导致残留文件积累在Mac电脑…...
【开题报告】基于微信小程序的旅游攻略分享平台的设计与实现
1.研究背景及意义 旅游已经成为现代人生活中重要的组成部分,人们越来越热衷于探索新的目的地和体验不同的文化。然而,对于旅游者来说,获取准确、可靠的旅游攻略信息并不容易。传统的旅游攻略书籍或网站往往无法提供实时、个性化的建议。因此…...
布隆过滤器(Bloom Filter)初学习
目录 1、布隆过滤器是什么 2、布隆过滤器的优缺点 3、使用场景 4、⭐基于Redis的布隆过滤器插件安装 4.1 下载布隆过滤器 4.2 创建文件夹并上传文件 4.3 安装gcc 4.4 解压RedisBloom压缩包 4.5 在解压好的文件夹下输入make 4.6 将编译的好的插件拷贝到docker redis容…...
“深入探讨操作系统和虚拟化技术“
目录 引言1.操作系统1.1.什么是操作系统1.2.常见操作系统1.3.个人版本和服务器版本的区别1.4.Linux的各个版本 2.安装VMWare虚拟机1.VMWare虚拟机介绍2.VMWare虚拟机安装3.VMWare虚拟机配置 3.安装配置Windows Server 2012 R24.完成电脑远程访问电脑5.服务器环境搭建配置jdk配置…...
远程连接异地主机可能遇到的问题及处理
0.现状 公司的一套系统内部有多个节点的内网,要把数据上传至客户的办公网环境中的服务器。客户办公网为我们提供了一台类似路由的设备,办公网无法让内网地址的数据包透传至服务器。现场条件所限,只有有限数量的技术服务人员可以维持…...
使用 PointNet 进行3D点集(即点云)的分类
点云分类 介绍 无序3D点集(即点云)的分类、检测和分割是计算机视觉中的核心问题。此示例实现了开创性的点云深度学习论文PointNet(Qi 等人,2017)。 设置 如果使用 colab 首先安装 trimesh !pip install trimesh。 import os import glob import trimesh import numpy as…...
高通平台GPIO引脚复用指导
高通平台GPIO引脚复用指导 1. 概述1.1 平台有多少个GPIO?1.2 这些GPIO都有哪些可复用的功能? 2. 软件配置2.1 TZ侧GPIO配置2.2 SBL侧GPIO配置2.3 AP侧GPIO配置2.3.1 Linux DTS机制与设备驱动模型概述2.3.2高通平台的pinctrl控制器2.3.2.1 SDX12 CPU pinc…...
华为机试题:HJ5 进制转换
目录 第一章、算法题1.1)题目描述1.2)解题思路与答案1.3)派仔的解题思路与答案1.3)牛客链接 友情提醒: 先看文章目录,大致了解文章知识点结构,点击文章目录可直接跳转到文章指定位置。 第一章、算法题 1.…...
大数据学习栈记——Neo4j的安装与使用
本文介绍图数据库Neofj的安装与使用,操作系统:Ubuntu24.04,Neofj版本:2025.04.0。 Apt安装 Neofj可以进行官网安装:Neo4j Deployment Center - Graph Database & Analytics 我这里安装是添加软件源的方法 最新版…...
css实现圆环展示百分比,根据值动态展示所占比例
代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...
基于ASP.NET+ SQL Server实现(Web)医院信息管理系统
医院信息管理系统 1. 课程设计内容 在 visual studio 2017 平台上,开发一个“医院信息管理系统”Web 程序。 2. 课程设计目的 综合运用 c#.net 知识,在 vs 2017 平台上,进行 ASP.NET 应用程序和简易网站的开发;初步熟悉开发一…...
ardupilot 开发环境eclipse 中import 缺少C++
目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...
CMake 从 GitHub 下载第三方库并使用
有时我们希望直接使用 GitHub 上的开源库,而不想手动下载、编译和安装。 可以利用 CMake 提供的 FetchContent 模块来实现自动下载、构建和链接第三方库。 FetchContent 命令官方文档✅ 示例代码 我们将以 fmt 这个流行的格式化库为例,演示如何: 使用 FetchContent 从 GitH…...
使用Matplotlib创建炫酷的3D散点图:数据可视化的新维度
文章目录 基础实现代码代码解析进阶技巧1. 自定义点的大小和颜色2. 添加图例和样式美化3. 真实数据应用示例实用技巧与注意事项完整示例(带样式)应用场景在数据科学和可视化领域,三维图形能为我们提供更丰富的数据洞察。本文将手把手教你如何使用Python的Matplotlib库创建引…...
React---day11
14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store: 我们在使用异步的时候理应是要使用中间件的,但是configureStore 已经自动集成了 redux-thunk,注意action里面要返回函数 import { configureS…...
基于Java+MySQL实现(GUI)客户管理系统
客户资料管理系统的设计与实现 第一章 需求分析 1.1 需求总体介绍 本项目为了方便维护客户信息为了方便维护客户信息,对客户进行统一管理,可以把所有客户信息录入系统,进行维护和统计功能。可通过文件的方式保存相关录入数据,对…...
Linux 内存管理实战精讲:核心原理与面试常考点全解析
Linux 内存管理实战精讲:核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用,还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...
苹果AI眼镜:从“工具”到“社交姿态”的范式革命——重新定义AI交互入口的未来机会
在2025年的AI硬件浪潮中,苹果AI眼镜(Apple Glasses)正在引发一场关于“人机交互形态”的深度思考。它并非简单地替代AirPods或Apple Watch,而是开辟了一个全新的、日常可接受的AI入口。其核心价值不在于功能的堆叠,而在于如何通过形态设计打破社交壁垒,成为用户“全天佩戴…...
