当前位置: 首页 > article >正文

别再死记硬背BERT结构了!用PyTorch手搓一个BERT-Base,带你彻底搞懂MLM和NSP

从零实现BERT-Base深入解析MLM与NSP的PyTorch实战指南1. 为什么需要动手实现BERT在自然语言处理领域BERT已经成为基石般的模型架构。但很多开发者发现仅仅通过调用transformers库来使用BERT就像驾驶一辆无法打开引擎盖的跑车——你可以踩油门前进却对内部工作原理一无所知。理解BERT的核心价值在于80-10-10掩码策略的巧妙设计如何解决预训练与微调的数据分布差异三种嵌入相加的数学本质及其对位置信息的编码方式注意力头之间的参数共享机制如何影响模型表现层归一化的放置位置为何比Transformer原始论文更有效当我第一次尝试修改BERT的注意力头大小时才真正意识到那些看似简单的架构决策背后蕴含的深刻工程智慧。下面让我们用PyTorch从零开始构建一个完整可训练的BERT-Base模型。2. 模型架构设计2.1 嵌入层实现BERT的嵌入层由三个部分组成它们的数学表达可以表示为$$ \text{Embedding} \text{TokenEmbedding} \text{SegmentEmbedding} \text{PositionEmbedding} $$class BERTEmbeddings(nn.Module): def __init__(self, config): super().__init__() self.token_embeddings nn.Embedding(config.vocab_size, config.hidden_size) self.position_embeddings nn.Embedding(config.max_position_embeddings, config.hidden_size) self.segment_embeddings nn.Embedding(config.type_vocab_size, config.hidden_size) self.LayerNorm nn.LayerNorm(config.hidden_size) self.dropout nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, token_type_idsNone): seq_length input_ids.size(1) position_ids torch.arange(seq_length, dtypetorch.long, deviceinput_ids.device) if token_type_ids is None: token_type_ids torch.zeros_like(input_ids) token_emb self.token_embeddings(input_ids) position_emb self.position_embeddings(position_ids) segment_emb self.segment_embeddings(token_type_ids) embeddings token_emb position_emb segment_emb embeddings self.LayerNorm(embeddings) embeddings self.dropout(embeddings) return embeddings关键细节位置嵌入是可学习的参数而非固定正弦函数这是BERT与原始Transformer的重要区别2.2 Transformer编码器层每个编码器层包含多头自注意力机制前馈神经网络残差连接和层归一化class BERTSelfAttention(nn.Module): def __init__(self, config): super().__init__() self.num_heads config.num_attention_heads self.head_dim config.hidden_size // config.num_attention_heads self.query nn.Linear(config.hidden_size, config.hidden_size) self.key nn.Linear(config.hidden_size, config.hidden_size) self.value nn.Linear(config.hidden_size, config.hidden_size) self.dense nn.Linear(config.hidden_size, config.hidden_size) self.dropout nn.Dropout(config.attention_probs_dropout_prob) def forward(self, hidden_states, attention_maskNone): batch_size hidden_states.size(0) # 线性变换 q self.query(hidden_states) k self.key(hidden_states) v self.value(hidden_states) # 多头分割 q q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) k k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) v v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) # 注意力分数计算 scores torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim) if attention_mask is not None: scores scores attention_mask # 注意力概率 probs nn.Softmax(dim-1)(scores) probs self.dropout(probs) # 上下文加权 context torch.matmul(probs, v) context context.transpose(1, 2).contiguous() context context.view(batch_size, -1, self.num_heads * self.head_dim) # 输出投影 output self.dense(context) return output3. 预训练任务实现3.1 掩码语言模型(MLM)BERT的MLM任务采用独特的80-10-10策略处理方式比例示例 (原始句子: the man ate an apple)替换为[MASK]80%the man [MASK] an apple替换为随机词10%the man ran an apple保持原词10%the man ate an appledef create_masked_lm_predictions(tokens, mask_prob, vocab_size): 生成MLM训练样本 output_tokens list(tokens) masked_lm_positions [] masked_lm_labels [] for i, token in enumerate(tokens): if token in [[CLS], [SEP]]: continue prob random.random() if prob mask_prob: masked_lm_positions.append(i) mask_decision random.random() if mask_decision 0.8: output_tokens[i] [MASK] elif mask_decision 0.9: output_tokens[i] random.randint(0, vocab_size-1) # 剩下10%保持原样 masked_lm_labels.append(token) return output_tokens, masked_lm_positions, masked_lm_labels3.2 下一句预测(NSP)NSP任务的样本构造规则def create_next_sentence_predictions(text_a, text_b, max_seq_length): 生成NSP训练样本 # 50%概率使用真实下一句 if random.random() 0.5: is_next True tokens_a tokenize(text_a) tokens_b tokenize(text_b) else: is_next False tokens_a tokenize(text_a) tokens_b tokenize(random.choice(corpus)) # 随机选择非关联句子 # 合并并截断 truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) # 添加特殊token tokens [[CLS]] tokens_a [[SEP]] tokens_b [[SEP]] segment_ids [0]*(len(tokens_a)2) [1]*(len(tokens_b)1) return tokens, segment_ids, is_next4. 完整模型整合将各组件组合成完整BERT模型class BERTForPretraining(nn.Module): def __init__(self, config): super().__init__() self.bert BERTModel(config) self.mlm_head MaskedLMHead(config) self.nsp_head NextSentencePredictionHead(config) def forward(self, input_ids, token_type_idsNone, attention_maskNone, masked_lm_positionsNone): # 获取BERT输出 sequence_output, pooled_output self.bert( input_ids, token_type_ids, attention_mask) # MLM任务 if masked_lm_positions is not None: masked_lm_output torch.gather( sequence_output, 1, masked_lm_positions.unsqueeze(-1).expand(-1,-1,sequence_output.size(-1))) mlm_scores self.mlm_head(masked_lm_output) else: mlm_scores None # NSP任务 nsp_scores self.nsp_head(pooled_output) return mlm_scores, nsp_scores5. 训练技巧与优化5.1 动态掩码策略原始BERT在数据预处理时生成掩码更高效的做法是在训练时动态生成class DynamicMasking: def __init__(self, mask_prob0.15): self.mask_prob mask_prob def apply(self, batch): masked_batch batch.clone() labels torch.full_like(batch, -100) # 忽略非掩码位置 # 为每个序列生成随机掩码 rand torch.rand(batch.shape) mask_pos (rand self.mask_prob) (batch ! 0) # 忽略padding # 80-10-10策略 mask_decision torch.rand(batch.shape) masked_batch[mask_pos (mask_decision 0.8)] tokenizer.mask_token_id random_words torch.randint(0, tokenizer.vocab_size, batch.shape) masked_batch[mask_pos (mask_decision 0.8) (mask_decision 0.9)] ( random_words[mask_pos (mask_decision 0.8) (mask_decision 0.9)]) labels[mask_pos] batch[mask_pos] return masked_batch, labels5.2 梯度累积当GPU内存不足时可以使用梯度累积模拟更大batch sizeaccumulation_steps 4 optimizer.zero_grad() for i, batch in enumerate(dataloader): loss model(batch).mean() loss.backward() if (i1) % accumulation_steps 0: optimizer.step() optimizer.zero_grad()6. 性能优化技巧6.1 混合精度训练使用AMP(Automatic Mixed Precision)加速训练scaler torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): loss model(batch) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()6.2 注意力优化实现内存高效的注意力计算def memory_efficient_attention(q, k, v, maskNone): 分块计算注意力以减少内存占用 chunk_size 64 # 根据GPU内存调整 scores torch.einsum(bhid,bhjd-bhij, q, k) / math.sqrt(q.size(-1)) if mask is not None: scores scores mask probs torch.softmax(scores, dim-1) # 分块计算 output torch.zeros_like(v) for i in range(0, q.size(2), chunk_size): chunk torch.einsum(bhij,bhjd-bhid, probs[:,:,i:ichunk_size], v[:,:,i:ichunk_size]) output[:,:,i:ichunk_size] chunk return output7. 模型部署实践7.1 权重共享技巧# 在初始化时共享权重 self.mlm_head.dense.weight self.bert.embeddings.token_embeddings.weight7.2 ONNX导出将模型导出为ONNX格式以便生产环境部署torch.onnx.export( model, (dummy_input,), bert.onnx, input_names[input_ids, attention_mask], output_names[output], dynamic_axes{ input_ids: {0: batch, 1: sequence}, attention_mask: {0: batch, 1: sequence}, output: {0: batch} } )

相关文章:

别再死记硬背BERT结构了!用PyTorch手搓一个BERT-Base,带你彻底搞懂MLM和NSP

从零实现BERT-Base:深入解析MLM与NSP的PyTorch实战指南 1. 为什么需要动手实现BERT? 在自然语言处理领域,BERT已经成为基石般的模型架构。但很多开发者发现,仅仅通过调用transformers库来使用BERT,就像驾驶一辆无法打开…...

Pypy虚拟环境配置避坑指南:用venv管理依赖,告别与系统Python的冲突

Pypy虚拟环境配置避坑指南:用venv管理依赖,告别与系统Python的冲突 当你第一次在项目中使用Pypy时,可能会被它惊人的执行速度所震撼——特别是在处理数值计算或长时间运行的任务时。但随之而来的依赖管理问题往往让人头疼:为什么用…...

CLIP-GmP-ViT-L-14惊艳效果:脑电图波形→认知状态/异常放电/临床诊断文本

CLIP-GmP-ViT-L-14惊艳效果:脑电图波形→认知状态/异常放电/临床诊断文本 1. 模型能力概览 CLIP-GmP-ViT-L-14是一个经过几何参数化(GmP)微调的CLIP模型,在医学影像分析领域展现出惊人的能力。这个模型最引人注目的特点是能够将脑电图(EEG)波形直接转化…...

【卷积】通道数不变时,1x1与3x3卷积:从感受野到计算效率的深度对比

1. 感受野与特征提取能力的本质差异 当我们在设计卷积神经网络时,选择1x1还是3x3卷积核绝不是随意决定的。这两种看似简单的操作,在实际应用中会产生截然不同的效果。我刚开始接触深度学习时,曾经天真地认为"反正通道数不变,…...

通义千问1.5-1.8B-Chat-GPTQ-Int4环境部署:Anaconda创建独立Python运行环境

通义千问1.5-1.8B-Chat-GPTQ-Int4环境部署:Anaconda创建独立Python运行环境 想试试通义千问这个轻量级大模型,结果第一步就被环境依赖搞晕了?PyTorch版本不对、CUDA不匹配、各种包冲突报错,是不是让你头大? 别担心&a…...

基于VSG分布式能源并网仿真:有功频率与无功电压控制的完美波形实现(MATLAB 2021b版)

基于虚拟同步发电机(vsg)分布式能源并网仿真 并网逆变器,有功频率控制,无功电压控制,VSG控制,电压电流双环PI控制!! 各方面波形都完美 MATLAB2021b最近在研究基于虚拟同步发电机&…...

西安电子科技大学计算机考研复试攻略:笔试与机试成绩深度解析

1. 西安电子科技大学计算机考研复试概况 西安电子科技大学计算机科学与技术学院的考研复试一直以严格规范著称,其中笔试和机试环节尤为关键。作为参加过复试的过来人,我深刻体会到这两个环节对最终录取结果的决定性影响。根据近三年的数据统计&#xff0…...

告别虚拟机!用WinSniffer v1.5 + MT7921网卡在Windows原生抓取WiFi 6E/7的6GHz报文

Windows原生抓取WiFi 6E/7的6GHz报文实战指南:WinSniffer v1.5与MT7921网卡完美组合 在无线网络技术快速迭代的今天,WiFi 6E和WiFi 7带来的6GHz频段为高速低延迟通信开辟了新天地。但对于网络工程师和技术爱好者而言,如何高效捕获和分析这些高…...

前端工程化实战:项目亮点与技术难点深度解析

1. 前端工程化的核心价值与实践路径 十年前我刚入行时,前端开发还停留在"切图写jQuery"的阶段。如今随着业务复杂度提升,一个中型前端项目就可能涉及上百个组件、数十个第三方依赖。这种背景下,工程化不再是可选项,而是…...

记录一次前端模型利用freesql映射,报400的问题

前端代码如下: <template> <div> <el-row style="margin-top: 16px"> <el-col :span="6" style="margin-left: 16px"> <span class="font-col" style="width: 100px">名称:</span> …...

Kandinsky-5.0-I2V-Lite-5s效果对比:不同采样步数(12/24/36)生成质量与耗时分析

Kandinsky-5.0-I2V-Lite-5s效果对比&#xff1a;不同采样步数&#xff08;12/24/36&#xff09;生成质量与耗时分析 1. 模型简介与测试背景 Kandinsky-5.0-I2V-Lite-5s是一款轻量级图生视频模型&#xff0c;只需上传一张首帧图片并补充运动或镜头描述&#xff0c;就能生成约5…...

Qwen1.5-0.5B-Chat部署全记录:从环境搭建到上线完整步骤

Qwen1.5-0.5B-Chat部署全记录&#xff1a;从环境搭建到上线完整步骤 1. 项目概述 Qwen1.5-0.5B-Chat是阿里通义千问开源系列中的轻量级对话模型&#xff0c;仅有5亿参数却具备出色的对话能力。这个模型特别适合资源有限的部署环境&#xff0c;可以在普通CPU服务器上流畅运行&…...

阿里通义Z-Image-Turbo WebUI全攻略:参数设置+提示词技巧,小白也能出大片

阿里通义Z-Image-Turbo WebUI全攻略&#xff1a;参数设置提示词技巧&#xff0c;小白也能出大片 1. 从零开始&#xff1a;你的AI画师已就位 想象一下&#xff0c;你脑子里有个绝妙的画面——一只在樱花树下打盹的橘猫&#xff0c;阳光透过花瓣洒在它毛茸茸的身上。以前要把这…...

终极指南:如何快速检测微信单向好友并一键清理无效社交关系

终极指南&#xff1a;如何快速检测微信单向好友并一键清理无效社交关系 【免费下载链接】WechatRealFriends 微信好友关系一键检测&#xff0c;基于微信ipad协议&#xff0c;看看有没有朋友偷偷删掉或者拉黑你 项目地址: https://gitcode.com/gh_mirrors/we/WechatRealFriend…...

AI-Shoujo HF Patch:全面提升游戏体验的终极解决方案

AI-Shoujo HF Patch&#xff1a;全面提升游戏体验的终极解决方案 【免费下载链接】AI-HF_Patch Automatically translate, uncensor and update AI-Shoujo! 项目地址: https://gitcode.com/gh_mirrors/ai/AI-HF_Patch AI-Shoujo HF Patch是一款专为AI-Shoujo游戏设计的综…...

ABAP开发必知:ROUND函数四舍五入的坑与正确用法(附实例)

ABAP开发必知&#xff1a;ROUND函数四舍五入的坑与正确用法&#xff08;附实例&#xff09; 在SAP系统的ABAP开发中&#xff0c;数值计算是财务、报表等业务模块的核心需求。而ROUND函数作为处理小数位数的常用工具&#xff0c;其行为模式与常规四舍五入存在关键差异——这正是…...

5分钟快速上手KeymouseGo:免费开源鼠标键盘录制工具完全指南

5分钟快速上手KeymouseGo&#xff1a;免费开源鼠标键盘录制工具完全指南 【免费下载链接】KeymouseGo 类似按键精灵的鼠标键盘录制和自动化操作 模拟点击和键入 | automate mouse clicks and keyboard input 项目地址: https://gitcode.com/gh_mirrors/ke/KeymouseGo 还…...

为何 Agent 才是大模型的终极形态:从 Chatbot 到智能体的演进

为何 Agent 才是大模型的终极形态:从 Chatbot 到智能体的演进 副标题:深入解析大语言模型的演进路径、智能体的核心架构与未来发展趋势 摘要/引言 在过去的几年中,人工智能领域经历了前所未有的变革,特别是大语言模型(Large Language Models, LLMs)的出现,彻底改变了我…...

ARM64缓存一致性实战:手把手教你理解PoC和PoU,搞定DMA与JIT编译器的坑

ARM64缓存一致性实战&#xff1a;深入理解PoC与PoU的工程实践 在底层系统开发领域&#xff0c;缓存一致性始终是工程师们面临的核心挑战之一。特别是在ARM64架构下&#xff0c;PoC&#xff08;Point of Coherency&#xff09;和PoU&#xff08;Point of Unification&#xff09…...

从HydroBASINS到USGS:一站式获取与ArcGIS处理全球及美国流域边界数据

1. 全球与美国流域数据源对比与选择 搞水文研究的朋友们都知道&#xff0c;流域边界数据是基础中的基础。我做了十年GIS分析&#xff0c;经常遇到这样的场景&#xff1a;项目涉及跨国流域分析&#xff0c;需要同时处理全球尺度和国家尺度的数据。这时候HydroBASINS和USGS WBD就…...

Win to Go实战:轻松在外接硬盘或移动硬盘上部署Windows系统

1. 为什么你需要Win to Go&#xff1f; 想象一下这样的场景&#xff1a;你正在咖啡馆用笔记本处理工作文档&#xff0c;突然接到通知要去客户现场演示。传统做法是带着笨重的笔记本&#xff0c;或者把文件拷到U盘——但前者太重&#xff0c;后者可能遇到软件不兼容、环境配置缺…...

VB6,VC++ 结构体变量,内存对齐

我用最底层、最直白、最硬核的方式&#xff0c;一次性给你讲透&#xff1a;什么是补齐长度&#xff1f;为什么编译器要乱插空位&#xff1f;你现在问的&#xff0c;是所有编程语言、所有结构体最核心的原理。我保证你看完彻底通透。一、先给你终极结论&#xff08;一句话&#…...

Vivado 2023.1下,用VCS仿真Xilinx PCIe IP与PHY的完整环境搭建教程

Vivado 2023.1与VCS协同仿真&#xff1a;PCIe IP与PHY集成验证全流程实战 在FPGA设计领域&#xff0c;PCIe接口的实现一直是工程师面临的技术高地。随着Xilinx新一代Vivado 2023.1工具的发布&#xff0c;其内置的PCIe IP核与PHY的协同仿真环境搭建流程有了显著优化。本文将深入…...

黑苹果实战进阶:深度解析硬件兼容性与系统优化四大核心问题

黑苹果实战进阶&#xff1a;深度解析硬件兼容性与系统优化四大核心问题 【免费下载链接】Hackintosh Hackintosh long-term maintenance model EFI and installation tutorial 项目地址: https://gitcode.com/gh_mirrors/ha/Hackintosh Hackintosh黑苹果项目为技术爱好者…...

STL体积计算器:3D打印模型体积与重量估算完整指南

STL体积计算器&#xff1a;3D打印模型体积与重量估算完整指南 【免费下载链接】STL-Volume-Model-Calculator STL Volume Model Calculator Python 项目地址: https://gitcode.com/gh_mirrors/st/STL-Volume-Model-Calculator STL-Volume-Model-Calculator 是一个功能强…...

鲲鹏麒麟环境下MySQL5.7离线部署全流程解析

1. 鲲鹏麒麟环境下的MySQL5.7离线部署背景 在国产化技术快速发展的今天&#xff0c;越来越多的企业和机构开始采用基于鲲鹏处理器和麒麟操作系统的解决方案。这种组合在政务、金融等领域尤其常见&#xff0c;因为这些场景对数据安全和系统可控性有着极高的要求。MySQL作为最流行…...

保姆级教程:用中点电流法搞定NPC三电平逆变器的电压平衡(附MATLAB/Simulink仿真)

保姆级实战&#xff1a;中点电流法在NPC三电平逆变器电压平衡中的Simulink仿真全流程 电力电子工程师们对NPC三电平逆变器中的"中点电压漂移"问题一定不陌生——就像试图在跷跷板上平衡两个不同重量的孩子&#xff0c;稍有不慎就会导致系统崩溃。这次我们不谈枯燥的数…...

Modelsim Wave窗口的5个隐藏技巧:让波形调试效率翻倍(附.do文件实战)

Modelsim Wave窗口的5个隐藏技巧&#xff1a;让波形调试效率翻倍&#xff08;附.do文件实战&#xff09; 在数字电路仿真领域&#xff0c;波形调试往往占据工程师70%以上的仿真时间。当设计规模达到百万门级时&#xff0c;如何在Modelsim的Wave窗口中快速定位关键信号、精确测量…...

WinRAR弹窗广告终极去除指南

1. WinRAR弹窗广告为什么让人头疼 每次打开WinRAR都会弹出烦人的广告窗口&#xff0c;这可能是很多用户共同的烦恼。作为一个用了十几年WinRAR的老用户&#xff0c;我完全理解这种困扰。这些弹窗不仅打断工作流程&#xff0c;有时候还会被安全软件误判为恶意程序导致软件闪退。…...

GeoServer进阶指南:多层级TIF地图数据的切片与缓存优化

1. 多层级TIF地图数据发布的核心挑战 第一次接触多层级TIF地图数据发布时&#xff0c;我完全低估了它的复杂性。直到实际项目中遇到地图加载缓慢、层级切换卡顿的问题&#xff0c;才意识到简单的数据发布远不能满足生产需求。多层级TIF通常来自无人机航拍、卫星遥感或专业测绘&…...