当前位置: 首页 > news >正文

BERT模型核心组件详解及其实现

摘要

BERT(Bidirectional Encoder Representations from Transformers)是一种基于Transformer架构的预训练模型,在自然语言处理领域取得了显著的成果。本文详细介绍了BERT模型中的几个关键组件及其实现,包括激活函数、变量初始化、嵌入查找、层归一化等。通过深入理解这些组件,读者可以更好地掌握BERT模型的工作原理,并在实际应用中进行优化和调整。

1. 引言

BERT模型由Google研究人员于2018年提出,通过大规模的无监督预训练和任务特定的微调,显著提升了多个自然语言处理任务的性能。本文将重点介绍BERT模型中的几个核心组件,包括激活函数、变量初始化、嵌入查找、层归一化等,并提供相应的代码实现。

2. 激活函数
2.1 Gaussian Error Linear Unit (GELU)

GELU是一种平滑的ReLU变体,其数学表达式如下:

GELU(x)=x⋅Φ(x)GELU(x)=x⋅Φ(x)

其中,Φ(x)Φ(x)是标准正态分布的累积分布函数(CDF)。在TensorFlow中,GELU可以通过以下方式实现:

def gelu(x):"""Gaussian Error Linear Unit.This is a smoother version of the RELU.Original paper: https://arxiv.org/abs/1606.08415Args:x: float Tensor to perform activation.Returns:`x` with the GELU activation applied."""cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))return x * cdf
2.2 激活函数映射

为了方便使用不同的激活函数,我们定义了一个映射函数get_activation,该函数根据传入的字符串返回相应的激活函数:

def get_activation(activation_string):"""Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.Args:activation_string: String name of the activation function.Returns:A Python function corresponding to the activation function. If`activation_string` is None, empty, or "linear", this will return None.If `activation_string` is not a string, it will return `activation_string`.Raises:ValueError: The `activation_string` does not correspond to a knownactivation."""if not isinstance(activation_string, six.string_types):return activation_stringif not activation_string:return Noneact = activation_string.lower()if act == "linear":return Noneelif act == "relu":return tf.nn.reluelif act == "gelu":return geluelif act == "tanh":return tf.tanhelse:raise ValueError("Unsupported activation: %s" % act)
3. 变量初始化

在深度学习中,合理的变量初始化对于模型的收敛速度和最终性能至关重要。BERT模型中使用了截断正态分布初始化器(truncated_normal_initializer),其标准差为0.02:

def create_initializer(initializer_range=0.02):"""Creates a `truncated_normal_initializer` with the given range."""return tf.truncated_normal_initializer(stddev=initializer_range)
4. 嵌入查找

嵌入查找是将输入的token id转换为向量表示的过程。BERT模型中使用了两种方法:一种是使用tf.gather(),另一种是使用one-hot编码:

def embedding_lookup(input_ids,vocab_size,embedding_size=128,initializer_range=0.02,word_embedding_name="word_embeddings",use_one_hot_embeddings=False):"""Looks up words embeddings for id tensor.Args:input_ids: int32 Tensor of shape [batch_size, seq_length] containing wordids.vocab_size: int. Size of the embedding vocabulary.embedding_size: int. Width of the word embeddings.initializer_range: float. Embedding initialization range.word_embedding_name: string. Name of the embedding table.use_one_hot_embeddings: bool. If True, use one-hot method for wordembeddings. If False, use `tf.gather()`.Returns:float Tensor of shape [batch_size, seq_length, embedding_size]."""if input_ids.shape.ndims == 2:input_ids = tf.expand_dims(input_ids, axis=[-1])embedding_table = tf.get_variable(name=word_embedding_name,shape=[vocab_size, embedding_size],initializer=create_initializer(initializer_range))flat_input_ids = tf.reshape(input_ids, [-1])if use_one_hot_embeddings:one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)output = tf.matmul(one_hot_input_ids, embedding_table)else:output = tf.gather(embedding_table, flat_input_ids)input_shape = get_shape_list(input_ids)output = tf.reshape(output,input_shape[0:-1] + [input_shape[-1] * embedding_size])return (output, embedding_table)
5. 层归一化

层归一化(Layer Normalization)是一种常用的归一化技术,用于加速训练过程并提高模型的泛化能力。BERT模型中使用了tf.contrib.layers.layer_norm来进行层归一化:

def layer_norm(input_tensor, name=None):"""Run layer normalization on the last dimension of the tensor."""return tf.contrib.layers.layer_norm(inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)

为了方便使用,我们还定义了一个组合函数layer_norm_and_dropout,该函数先进行层归一化,再进行dropout操作:

def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):"""Runs layer normalization followed by dropout."""output_tensor = layer_norm(input_tensor, name)output_tensor = dropout(output_tensor, dropout_prob)return output_tensor
6. Dropout

Dropout是一种常用的正则化技术,用于防止模型过拟合。在BERT模型中,dropout的概率可以通过配置参数进行设置:

def dropout(input_tensor, dropout_prob):"""Perform dropout.Args:input_tensor: float Tensor.dropout_prob: Python float. The probability of dropping out a value (NOT of*keeping* a dimension as in `tf.nn.dropout`).Returns:A version of `input_tensor` with dropout applied."""if dropout_prob is None or dropout_prob == 0.0:return input_tensoroutput = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)return output
7. 从检查点加载变量

在微调过程中,通常需要从预训练的模型检查点中加载变量。get_assignment_map_from_checkpoint函数用于计算当前变量与检查点变量的映射关系:

def get_assignment_map_from_checkpoint(tvars, init_checkpoint):"""Compute the union of the current variables and checkpoint variables."""assignment_map = {}initialized_variable_names = {}name_to_variable = collections.OrderedDict()for var in tvars:name = var.namem = re.match("^(.*):\\d+$", name)if m is not None:name = m.group(1)name_to_variable[name] = varinit_vars = tf.train.list_variables(init_checkpoint)assignment_map = collections.OrderedDict()for x in init_vars:(name, var) = (x[0], x[1])if name not in name_to_variable:continueassignment_map[name] = nameinitialized_variable_names[name] = 1initialized_variable_names[name + ":0"] = 1return (assignment_map, initialized_variable_names)
8. 结论

本文详细介绍了BERT模型中的几个核心组件,包括激活函数、变量初始化、嵌入查找、层归一化等。通过深入理解这些组件,读者可以更好地掌握BERT模型的工作原理,并在实际应用中进行优化和调整。希望本文能为读者在自然语言处理领域的研究和开发提供有益的参考。

相关文章:

BERT模型核心组件详解及其实现

摘要 BERT(Bidirectional Encoder Representations from Transformers)是一种基于Transformer架构的预训练模型,在自然语言处理领域取得了显著的成果。本文详细介绍了BERT模型中的几个关键组件及其实现,包括激活函数、变量初始化…...

图论-代码随想录刷题记录[JAVA]

文章目录 前言深度优先搜索理论基础所有可达路径岛屿数量岛屿最大面积孤岛的总面积沉默孤岛Floyd 算法dijkstra(朴素版)最小生成树之primkruskal算法 前言 新手小白记录第一次刷代码随想录 1.自用 抽取精简的解题思路 方便复盘 2.代码尽量多加注释 3.记录…...

c#加载shellcode

本地加载bin文件 SharpPELoader项目如下: using System; using System.IO; using System.Runtime.InteropServices;namespace TestShellCode {internal class Program{private const uint MEM_COMMIT 0x1000;private const uint PAGE_EXECUTE_READWRITE 0x40;pr…...

HarmonyOS 开发环境搭建

HarmonyOS(鸿蒙操作系统)作为一种面向全场景多设备的智能操作系统,正逐渐在市场上崭露头角。为了进入HarmonyOS生态,开发者需要搭建一个高效的开发环境。本文将详细介绍如何搭建HarmonyOS开发环境,特别是如何安装和配置…...

【网络云计算】2024第46周周考-磁盘管理的基础知识-RAID篇

文章目录 1、画出各个RAID的结构图,6句话说明优点和缺点,以及磁盘可用率和坏盘数量,磁盘总的数量2、写出TCP五层模型以及对应的常用协议 【网络云计算】2024第46周周考-磁盘管理的基础知识-RAID篇 1、画出各个RAID的结构图,6句话说…...

深入理解 SQL_MODE 之 ANSI_QUOTES

引言 在 MySQL 数据库中,sql_mode 是一个重要的配置参数,它定义了 MySQL 应该遵循的 SQL 语法标准以及数据验证规则。其中,ANSI_QUOTES 是 sql_mode 中的一个重要选项,它改变了 MySQL 对于字符串和标识符的识别方式,使…...

容器技术在持续集成与持续交付中的应用

💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 容器技术在持续集成与持续交付中的应用 容器技术在持续集成与持续交付中的应用 容器技术在持续集成与持续交付中的应用 引言 容器…...

【嵌入式软件-STM32】OLED显示屏+调试方法

目录 一、调试方式 1)串口调试 优势 弊端 2)显示屏调试 优势 弊端 3)Keil调试模式 4)点灯调试法 5)注释调试法 6)对照法 二、OLED简介 OLED组件 OLED显示屏 0.96寸OLED模块 OLED外观和种类…...

kubernetes简单入门实战

本章将介绍如何在kubernetes集群中部署一个nginx服务,并且能够对其访问 Namespace Namespace是k8s系统中一个非常重要的资源,它的主要作用是用来实现多套环境的资源隔离或者多租户的资源隔离。 默认情况下,k8s集群中的所有的Pod都是可以相…...

Python连接Mysql、Postgre、ClickHouse、Redis常用库及封装方法

博主在这里分享一些常见的python连接数据库或中间件的库和封装方案,希望对大家有用。 Mysql封装 #!/usr/bin/python # -*- coding: utf-8 -*- import sys import pymysql from settings import MYSQL_DB, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_HOST, EN…...

如何修改npm包

前言 开发中遇到一个问题,配置 Element Plus 自定义主题时,添加了 ElementPlusResolver({ importStyle: "sass" }) 后,控制台出现报错,这是因为 Dart Sass 2.0 不再支持使用 !global 来声明新变量,虽然当前…...

Django 2024全栈开发指南(三):数据库模型与ORM操作(上篇)

目录 一、模型的定义二、数据迁移三、数据表关系四、数据表操作4.1 Shell工具4.2 数据新增4.3 数据修改4.4 数据删除4.5 数据查询 Django 对各种数据库提供了很好的支持,包括 PostgreSQL、MySQL、SQLite 和 Oracle,而且为这些数据库提供了统一的 API 方法…...

低代码可视化-uniapp开关选择组件-低码生成器

开关(Switch)选择组件是一种用户界面元素,允许用户在两种状态(通常是开/关、是/否、启用/禁用等)之间进行切换。这种组件在移动应用、桌面软件、网页以及物联网设备中广泛应用。以下是对开关Switch选择组件的详细介绍&…...

【arxiv‘24】Vision-Language Navigation with Continual Learning

论文信息 题目:Vision-Language Navigation with Continual Learning 视觉-语言导航与持续学习 作者:Zhiyuan Li, Yanfeng Lv, Ziqin Tu, Di Shang, Hong Qiao 论文创新点 VLNCL范式:这是一个新颖的框架,它使得智能体能够在适…...

如何在 Ubuntu 上安装 Jupyter Notebook

本篇文章将教你在 Ubuntu 服务器上安装 Jupyter Notebook,并使用 Nginx 和 SSL 证书进行安全配置。 我将带你一步步在云服务器上搭建 Jupyter Notebook 服务器。Jupyter Notebook 在数据科学和机器学习领域被广泛用于交互式编码、可视化和实验。在远程服务器上运行…...

免费申请 Let‘s Encrypt SSL 证书

免费申请 Lets Encrypt SSL 证书 在网络安全日益重要的今天,为网站启用 SSL 证书是保障数据安全和用户信任的关键。Lets Encrypt 提供的免费 SSL 证书是一个很好的选择。下面我们详细介绍如何为网站域名申请该证书。 一、准备工作 域名 确保已注册要使用 SSL 证书的…...

【JAVA】Java基础—面向对象编程:继承—重写父类方法

在Java开发中,重写(Override)是面向对象编程(OOP)中的一个重要概念。它允许子类提供父类方法的具体实现,从而改变或扩展父类的行为。重写是实现多态性的重要手段,使得程序在运行时能够根据对象的…...

【C++初阶】C++入门

1、C第一个程序 C是脱胎于C语言的&#xff0c;所以也包含了C语言绝大多数的内容&#xff0c;C兼容C语言绝大多数的语法,在C语言中能实现的程序在C中也是可以执行的&#xff0c;但需要将定义文件代码的后缀改为.cpp 就比如hello world程序 // test.cpp #include<stdio.h&g…...

自然推理系统:的拒取式的解析

要推导出 **"非A"** 的拒取式 (rejection form)&#xff0c;首先我们要理解逻辑推理中几个基本的概念。 假设我们有以下前提&#xff1a; 1. **A → B** &#xff08;如果A成立&#xff0c;那么B成立&#xff09; 2. **非B** &#xff08;B不成立&#xff09; 我们…...

OceanBase 分区表详解

1、分区表的定义 在OceanBase数据库中&#xff0c;普通的表数据可以根据预设的规则被分割并存储到不同的数据区块中&#xff0c;同一区块的数据是在一个物理存储上。这样被分区块的表被称为分区表&#xff0c;而其中的每一个独立的数据区块则被称为一个分区。 如下图所示&…...

从原型到实战:基于快马生成代码快速开发可用的worldmonitor疫情监控系统

从原型到实战&#xff1a;基于快马生成代码快速开发可用的worldmonitor疫情监控系统 最近在做一个全球疫情数据监控系统的项目&#xff0c;正好用到了InsCode(快马)平台来快速生成基础代码&#xff0c;然后在这个基础上进行二次开发。整个过程非常顺畅&#xff0c;特别是平台的…...

储能电站EMS系统实战指南:从硬件选型到软件配置的完整避坑手册

储能电站EMS系统实战指南&#xff1a;从硬件选型到软件配置的完整避坑手册 在新能源行业快速发展的今天&#xff0c;储能电站作为电力系统中的关键调节单元&#xff0c;其能量管理系统&#xff08;EMS&#xff09;的稳定性和智能化水平直接决定了电站的经济效益和运行安全。然而…...

原神帧率解锁终极指南:3步轻松突破60FPS限制,享受极致流畅体验

原神帧率解锁终极指南&#xff1a;3步轻松突破60FPS限制&#xff0c;享受极致流畅体验 【免费下载链接】genshin-fps-unlock unlocks the 60 fps cap 项目地址: https://gitcode.com/gh_mirrors/ge/genshin-fps-unlock 还在为原神60帧限制而苦恼吗&#xff1f;高端显卡却…...

洛雪音乐音源项目:免费高品质音乐资源获取的终极方案

洛雪音乐音源项目&#xff1a;免费高品质音乐资源获取的终极方案 【免费下载链接】lxmusic- lxmusic(洛雪音乐)全网最新最全音源 项目地址: https://gitcode.com/gh_mirrors/lx/lxmusic- 1 价值定位&#xff1a;重新定义音乐资源获取体验 洛雪音乐音源项目作为一款开源…...

AI推动SEO关键词优化的全新策略与实践明晰

在当前数字营销环境中&#xff0c;AI技术为SEO关键词优化带来了前所未有的变革。它通过自动化的数据分析与挖掘工具&#xff0c;能够帮助企业更准确地识别用户需求与搜索趋势。通过AI的支持&#xff0c;关键词挖掘变得更加高效和精准&#xff0c;企业可以快速获取相关关键词并优…...

告别图库!用LiuJuan Z-Image为文章博客自动生成配图(保姆级教程)

告别图库&#xff01;用LiuJuan Z-Image为文章博客自动生成配图&#xff08;保姆级教程&#xff09; 1. 为什么你需要这个工具&#xff1f; 作为一名内容创作者&#xff0c;我深知找配图的痛苦。记得上周为了给一篇技术文章配图&#xff0c;我花了整整40分钟在图库里翻找&…...

小白也能学会:MogFace透明蒙版可视化,人脸检测不再难

小白也能学会&#xff1a;MogFace透明蒙版可视化&#xff0c;人脸检测不再难 1. 为什么需要透明蒙版可视化&#xff1f; 想象一下这样的场景&#xff1a;你拍了一张全家福&#xff0c;想用AI工具检测照片中有多少人。传统的检测工具会在每个人脸上画一个绿色的方框&#xff0…...

避坑指南:微信支付V3 SDK自动更新证书失败的5种常见原因及修复方法

微信支付V3证书自动更新失败排查手册&#xff1a;从原理到实战修复 微信支付的V3版本SDK以其自动证书更新机制著称&#xff0c;但不少开发者在集成过程中都遭遇过AutoUpdateCertificatesVerifier的失败问题。证书更新失败不仅会导致支付功能中断&#xff0c;还可能引发验签错误…...

元宇宙拆迁队:强拆违规建筑日入十万

从Bug猎人到空间执法官当传统的软件测试工程师还在为揪出一个隐蔽的NullPointerException而欢欣鼓舞时&#xff0c;一片更为广阔、也更为凶险的新战场已经悄然开启——元宇宙。在这里&#xff0c;代码的缺陷不再仅仅导致程序崩溃或数据丢失&#xff0c;它们会具象化为扭曲的空间…...

Hunyuan-MT-7B效果实测:Pixel Language Portal对中文网络用语、方言、谐音梗的跨维转码能力分析

Hunyuan-MT-7B效果实测&#xff1a;Pixel Language Portal对中文网络用语、方言、谐音梗的跨维转码能力分析 1. 引言&#xff1a;当翻译遇上像素冒险 在数字时代的语言交流中&#xff0c;传统翻译工具往往显得生硬而缺乏温度。Pixel Language Portal&#xff08;像素语言跨维…...