当前位置: 首页 > article >正文

别再死记硬背Transformer了!用PyTorch手把手实现一个简易翻译模型(附完整代码)

用PyTorch从零构建Transformer翻译模型代码驱动的深度学习实践如果你已经读过Transformer的论文或看过相关教程却依然对如何实现这个革命性架构感到迷茫那么这篇文章正是为你准备的。我们将避开繁琐的理论推导直接进入代码层面通过构建一个英中翻译模型来掌握Transformer的核心实现技巧。1. 环境准备与数据预处理在开始构建模型之前我们需要准备好开发环境并处理翻译任务所需的数据集。这个阶段虽然看似简单却直接影响后续模型训练的效果。1.1 安装必要的Python库首先确保你的Python环境建议3.8中已安装以下关键库pip install torch torchtext spacy sentencepiece然后下载中英文语言模型用于分词python -m spacy download en_core_web_sm python -m spacy download zh_core_web_sm1.2 构建双语数据集我们将使用IWSLT2017英中翻译数据集它包含约20万条平行句对。以下是数据加载和预处理的完整代码from torchtext.datasets import IWSLT2017 from torchtext.data.utils import get_tokenizer from torchtext.vocab import build_vocab_from_iterator # 初始化分词器 en_tokenizer get_tokenizer(spacy, languageen_core_web_sm) zh_tokenizer get_tokenizer(spacy, languagezh_core_web_sm) def yield_tokens(data_iter, tokenizer, language): for data in data_iter: yield tokenizer(data[language]) # 构建词汇表 train_iter IWSLT2017(splittrain) en_vocab build_vocab_from_iterator(yield_tokens(train_iter, en_tokenizer, en), specials[unk, pad, bos, eos]) zh_vocab build_vocab_from_iterator(yield_tokens(train_iter, zh_tokenizer, zh), specials[unk, pad, bos, eos]) # 设置默认未知词索引 en_vocab.set_default_index(en_vocab[unk]) zh_vocab.set_default_index(zh_vocab[unk])注意实际应用中应考虑限制词汇表大小如保留前30000个高频词以控制模型规模2. Transformer核心组件实现现在我们来构建Transformer的核心模块。与原始论文不同我们会做一些简化以提升代码可读性同时保持架构的关键特性。2.1 位置编码序列顺序的数字化表达Transformer抛弃了RNN的循环结构因此需要显式地注入位置信息。以下是改进后的位置编码实现import math import torch import torch.nn as nn class PositionalEncoding(nn.Module): def __init__(self, d_model: int, max_len: int 5000): super().__init__() position torch.arange(max_len).unsqueeze(1) div_term torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe torch.zeros(max_len, d_model) pe[:, 0::2] torch.sin(position * div_term) pe[:, 1::2] torch.cos(position * div_term) self.register_buffer(pe, pe) def forward(self, x: torch.Tensor) - torch.Tensor: 参数: x: 形状为 [batch_size, seq_len, embedding_dim] 的张量 返回: 添加位置编码后的张量 return x self.pe[:x.size(1)]可视化位置编码可以帮助我们理解其工作原理import matplotlib.pyplot as plt plt.figure(figsize(12, 6)) pe PositionalEncoding(256) y pe(torch.zeros(1, 100, 256)) plt.plot(y[0, :, 4:8].data.numpy()) plt.legend([dim %d % p for p in [4,5,6,7]]) plt.title(Positional Encoding Visualization) plt.show()2.2 多头注意力机制的实现多头注意力是Transformer最具创新性的部分下面是其PyTorch实现class MultiHeadAttention(nn.Module): def __init__(self, d_model: int, num_heads: int, dropout: float 0.1): super().__init__() assert d_model % num_heads 0, d_model必须能被num_heads整除 self.d_model d_model self.num_heads num_heads self.d_k d_model // num_heads # 线性变换层 self.w_q nn.Linear(d_model, d_model) self.w_k nn.Linear(d_model, d_model) self.w_v nn.Linear(d_model, d_model) self.w_o nn.Linear(d_model, d_model) self.dropout nn.Dropout(dropout) def scaled_dot_product_attention(self, q, k, v, maskNone): attn_scores torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) if mask is not None: attn_scores attn_scores.masked_fill(mask 0, -1e9) attn_probs torch.softmax(attn_scores, dim-1) attn_probs self.dropout(attn_probs) output torch.matmul(attn_probs, v) return output, attn_probs def forward(self, q, k, v, maskNone): batch_size q.size(0) # 线性变换并分头 q self.w_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) k self.w_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) v self.w_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) # 计算注意力 attn_output, attn_probs self.scaled_dot_product_attention(q, k, v, mask) # 合并多头 attn_output attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model) # 最终线性变换 output self.w_o(attn_output) return output, attn_probs3. 编码器与解码器架构有了核心组件后我们可以构建完整的编码器和解码器结构。3.1 编码器层的实现每个编码器层包含一个自注意力机制和一个前馈网络class EncoderLayer(nn.Module): def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float 0.1): super().__init__() self.self_attn MultiHeadAttention(d_model, num_heads, dropout) self.feed_forward nn.Sequential( nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model) ) self.norm1 nn.LayerNorm(d_model) self.norm2 nn.LayerNorm(d_model) self.dropout nn.Dropout(dropout) def forward(self, x, maskNone): # 自注意力子层 attn_output, _ self.self_attn(x, x, x, mask) x x self.dropout(attn_output) x self.norm1(x) # 前馈网络子层 ff_output self.feed_forward(x) x x self.dropout(ff_output) x self.norm2(x) return x3.2 解码器层的实现解码器层比编码器更复杂包含三种注意力机制class DecoderLayer(nn.Module): def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float 0.1): super().__init__() self.self_attn MultiHeadAttention(d_model, num_heads, dropout) self.cross_attn MultiHeadAttention(d_model, num_heads, dropout) self.feed_forward nn.Sequential( nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model) ) self.norm1 nn.LayerNorm(d_model) self.norm2 nn.LayerNorm(d_model) self.norm3 nn.LayerNorm(d_model) self.dropout nn.Dropout(dropout) def forward(self, x, encoder_output, src_maskNone, tgt_maskNone): # 自注意力子层目标序列的自注意力 attn_output, _ self.self_attn(x, x, x, tgt_mask) x x self.dropout(attn_output) x self.norm1(x) # 交叉注意力子层查询来自解码器键值来自编码器 cross_output, _ self.cross_attn(x, encoder_output, encoder_output, src_mask) x x self.dropout(cross_output) x self.norm2(x) # 前馈网络子层 ff_output self.feed_forward(x) x x self.dropout(ff_output) x self.norm3(x) return x4. 完整模型组装与训练现在我们将所有组件组合成完整的Transformer模型并实现训练流程。4.1 模型组装class Transformer(nn.Module): def __init__(self, src_vocab_size: int, tgt_vocab_size: int, d_model: int 512, num_heads: int 8, num_layers: int 6, d_ff: int 2048, dropout: float 0.1, max_seq_len: int 100): super().__init__() # 词嵌入层 self.src_embed nn.Embedding(src_vocab_size, d_model) self.tgt_embed nn.Embedding(tgt_vocab_size, d_model) # 位置编码 self.pos_encoding PositionalEncoding(d_model, max_seq_len) # 编码器 self.encoder_layers nn.ModuleList([ EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers) ]) # 解码器 self.decoder_layers nn.ModuleList([ DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers) ]) # 输出层 self.linear nn.Linear(d_model, tgt_vocab_size) self.dropout nn.Dropout(dropout) def encode(self, src, src_mask): src_embedded self.dropout(self.pos_encoding(self.src_embed(src))) for layer in self.encoder_layers: src_embedded layer(src_embedded, src_mask) return src_embedded def decode(self, tgt, encoder_output, src_mask, tgt_mask): tgt_embedded self.dropout(self.pos_encoding(self.tgt_embed(tgt))) for layer in self.decoder_layers: tgt_embedded layer(tgt_embedded, encoder_output, src_mask, tgt_mask) return tgt_embedded def forward(self, src, tgt, src_maskNone, tgt_maskNone): encoder_output self.encode(src, src_mask) decoder_output self.decode(tgt, encoder_output, src_mask, tgt_mask) output self.linear(decoder_output) return output4.2 训练流程实现以下是简化的训练循环包含学习率调度和梯度裁剪def train_model(model, train_loader, val_loader, epochs10, lr0.0001): device torch.device(cuda if torch.cuda.is_available() else cpu) model model.to(device) criterion nn.CrossEntropyLoss(ignore_index1) # 忽略padding索引 optimizer torch.optim.Adam(model.parameters(), lrlr, betas(0.9, 0.98), eps1e-9) scheduler torch.optim.lr_scheduler.StepLR(optimizer, step_size1, gamma0.95) for epoch in range(epochs): model.train() total_loss 0 for batch in train_loader: src, tgt batch.src.to(device), batch.tgt.to(device) # 创建掩码 src_mask (src ! 1).unsqueeze(1).unsqueeze(2) # padding索引为1 tgt_mask (tgt ! 1).unsqueeze(1).unsqueeze(2) seq_len tgt.size(1) nopeak_mask torch.triu(torch.ones(1, seq_len, seq_len) 1).transpose(1, 2) nopeak_mask nopeak_mask.float().to(device) tgt_mask tgt_mask nopeak_mask optimizer.zero_grad() # 前向传播 output model(src, tgt[:, :-1], src_mask, tgt_mask[:, :-1, :-1]) # 计算损失 loss criterion(output.reshape(-1, output.size(-1)), tgt[:, 1:].reshape(-1)) # 反向传播 loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm1.0) optimizer.step() total_loss loss.item() scheduler.step() avg_loss total_loss / len(train_loader) print(fEpoch {epoch1}, Loss: {avg_loss:.4f}) # 验证 model.eval() val_loss evaluate(model, val_loader, criterion, device) print(fValidation Loss: {val_loss:.4f}) return model4.3 推理与翻译示例训练完成后我们可以用以下函数进行翻译def translate(model, sentence, src_vocab, tgt_vocab, max_len50): model.eval() device next(model.parameters()).device # 预处理输入句子 tokens [src_vocab[bos]] [src_vocab[token] for token in en_tokenizer(sentence)] [src_vocab[eos]] src torch.LongTensor(tokens).unsqueeze(0).to(device) src_mask (src ! 1).unsqueeze(1).unsqueeze(2) # 初始化目标序列 tgt torch.LongTensor([[tgt_vocab[bos]]]).to(device) for i in range(max_len): tgt_mask (tgt ! 1).unsqueeze(1).unsqueeze(2) seq_len tgt.size(1) nopeak_mask torch.triu(torch.ones(1, seq_len, seq_len) 1).transpose(1, 2) nopeak_mask nopeak_mask.float().to(device) tgt_mask tgt_mask nopeak_mask output model(src, tgt, src_mask, tgt_mask) next_token output.argmax(-1)[:, -1].unsqueeze(1) tgt torch.cat([tgt, next_token], dim1) if next_token.item() tgt_vocab[eos]: break # 将索引转换为单词 translated [tgt_vocab.lookup_token(idx) for idx in tgt.squeeze().tolist()] return .join(translated[1:-1]) # 去掉bos和eos在实际项目中我发现在解码阶段使用束搜索(beam search)比贪婪解码能获得更流畅的翻译结果。此外添加标签平滑(label smoothing)也能有效缓解模型过度自信的问题。

相关文章:

别再死记硬背Transformer了!用PyTorch手把手实现一个简易翻译模型(附完整代码)

用PyTorch从零构建Transformer翻译模型:代码驱动的深度学习实践 如果你已经读过Transformer的论文或看过相关教程,却依然对如何实现这个革命性架构感到迷茫,那么这篇文章正是为你准备的。我们将避开繁琐的理论推导,直接进入代码层…...

在Taotoken平台查看与导出详细API调用日志用于分析与审计

在Taotoken平台查看与导出详细API调用日志用于分析与审计 1. 访问审计日志功能 Taotoken平台为团队管理员提供了完整的API调用日志记录功能。要访问审计日志,首先登录Taotoken控制台,在左侧导航栏中找到「审计日志」或「API日志」菜单项。该功能通常位…...

魔兽地图开发者的救星:w3x2lni格式转换工具完全指南

魔兽地图开发者的救星:w3x2lni格式转换工具完全指南 【免费下载链接】w3x2lni 魔兽地图格式转换工具 项目地址: https://gitcode.com/gh_mirrors/w3/w3x2lni 还在为魔兽地图在不同版本间的兼容性问题头疼吗?是否遇到过辛苦制作的地图无法在其他玩…...

Arduino UNO串口控制DFPlayer Mini播放音乐,这5个常见问题你遇到了吗?(附解决方案)

Arduino UNO与DFPlayer Mini串口音乐播放:5大疑难问题深度解析 当你在工作室里兴奋地连接好Arduino UNO和DFPlayer Mini模块,期待着第一段旋律从扬声器传出时,却发现迎接你的可能是沉默、杂音或是各种意想不到的错误提示。这种挫败感每个创客…...

键盘连击终结者:开源工具KeyboardChatterBlocker让老键盘重获新生

键盘连击终结者:开源工具KeyboardChatterBlocker让老键盘重获新生 【免费下载链接】KeyboardChatterBlocker A handy quick tool for blocking mechanical keyboard chatter. 项目地址: https://gitcode.com/gh_mirrors/ke/KeyboardChatterBlocker 你是否曾经…...

保姆级教程:手把手教你为YOLOv8模型集成GAM注意力模块(附完整代码与配置文件)

深度集成GAM注意力机制到YOLOv8的实战指南 在计算机视觉领域,注意力机制已经成为提升模型性能的关键技术之一。GAM(Global Attention Mechanism)作为一种创新的注意力模块,通过同时考虑通道和空间维度的全局信息交互,能…...

周红伟:Token出海,Agent进场:AI智能体管理元年,他们在复旦管院拆解企业级Agent实战

从“聊天”到“执行”,AI只用了不到一年。以OpenClaw为代表的开源Agent浪潮,正在把AI智能体从“极客玩具”推向真实世界。大模型竞赛的终点,转向谁能率先让Agent嵌入供应链、融入决策流程、深入客户交互,把技术变量真正转化为增长…...

AI Agent与区块链交互:aelf钱包技能包架构设计与实战指南

1. 项目概述:为AI Agent赋能的aelf区块链钱包技能包如果你正在开发一个需要与aelf区块链交互的AI Agent,或者你希望让Claude、Cursor这类AI工具能帮你管理数字资产、查询链上数据,那么你很可能需要一套标准化的“技能”。portkey/eoa-agent-s…...

AIVectorMemory:为AI编程助手构建持久化语义记忆系统

1. 项目概述 如果你还在用 CLAUDE.md 或者 MEMORY.md 这种 Markdown 文件来给你的 AI 编程助手当“脑子”,那我得说,是时候升级一下你的装备了。我过去一年里,几乎每天都在和 Cursor、Claude Code、Kiro 这些 AI IDE 打交道,最…...

球磨机实际应用序列之机械合金化:突破传统熔炼的创新材料制备技术

1 概述机械合金化是通过机械球磨实现粉末合金化的关键技术,是材料制备领域广泛应用的合金化方法之一。该工艺以机械驱动力诱导粉末发生固相反应,突破传统熔炼的熔点限制与平衡相图约束,可制备常规方法难以获得的新型合金与固溶体材料。2 球磨…...

开源LLM监控平台llm.report部署指南:成本分析与提示词优化

1. 项目概述:一个被“放弃”的开源宝藏 最近在整理自己的AI应用项目时,发现OpenAI的API账单有点“失控”了。月初设定的预算,到了月中就频频告警,仔细一看,全是各种调试、测试请求产生的费用,真正有价值的调…...

ARM Cortex-A开发工具链与Linux系统构建实战

1. ARM Cortex-A开发工具链深度解析在嵌入式Linux开发领域,工具链的选择直接影响着最终系统的性能和开发效率。作为一位长期从事ARM平台开发的工程师,我见证了工具链技术的演进历程,也积累了丰富的实战经验。本文将系统性地剖析ARM Cortex-A系…...

深入理解与实战应用:Python爬虫中的Robots.txt规范与urllib.robotparser完全指南

目录 第一章:robots.txt协议的来龙去脉 1.1 历史渊源:1994年的一个夏天 1.2 robots.txt的基本语法 1.3 robots.txt的局限性 第二章:urllib.robotparser模块完全解析 2.1 模块概览与设计哲学 2.2 基础用法示例 2.3 核心API详解 2.4 实战:构建robots.txt检查器 第三…...

BetterNCM插件管理器:一键安装网易云音乐插件的终极解决方案

BetterNCM插件管理器:一键安装网易云音乐插件的终极解决方案 【免费下载链接】BetterNCM-Installer 一键安装 Better 系软件 项目地址: https://gitcode.com/gh_mirrors/be/BetterNCM-Installer BetterNCM插件管理器是一款专为网易云音乐PC客户端设计的Rust原…...

告别手动点点点:用Python脚本一键启动CANoe里的TestModule和vTESTstudio测试

告别手动点点点:用Python脚本一键启动CANoe里的TestModule和vTESTstudio测试 每天重复打开CANoe工程、加载配置、启动测试模块的操作,是否让你感到效率低下?对于车载网络测试工程师来说,这些重复性手动操作不仅耗时,还…...

新手福音:用快马ai生成带详细注释的freertos学习项目,轻松入门实时操作系统

作为一个刚接触嵌入式开发的菜鸟,最近被导师要求学习FreeRTOS。面对任务调度、队列、信号量这些概念,我完全是一头雾水。好在发现了InsCode(快马)平台,用它生成的带详细注释的FreeRTOS示例项目,让我这个小白终于摸到了门道。下面分…...

Cisco交换机802.1x认证的‘安全后路’怎么留?详解认证失败后的VLAN分配与ACL控制

Cisco交换机802.1x认证的柔性安全策略:认证失败后的智能处理方案 在企业网络安全管理中,802.1x认证作为接入控制的核心技术,其部署细节往往决定了安全性与用户体验的平衡点。许多工程师在配置时过于关注认证成功后的流程,却忽略了…...

限流与配额:防止 AI “疯狂执行”

网罗开发(小红书、快手、视频号同名)大家好,我是 展菲,目前在上市企业从事人工智能项目研发管理工作,平时热衷于分享各种编程领域的软硬技能知识以及前沿技术,包括iOS、前端、Harmony OS、Java、Python等方…...

基于深度学习的OpenClaw验证码识别:从CRNN原理到工程部署实战

1. 项目概述:一个专为“OpenClaw”设计的验证码识别引擎 最近在做一个自动化流程的项目,遇到了一个叫“OpenClaw”的验证码系统,图形扭曲、字符粘连,常规的OCR工具完全失效。为了解决这个问题,我花了不少时间研究&…...

如何用5分钟彻底解决Windows风扇噪音问题:FanControl终极配置指南

如何用5分钟彻底解决Windows风扇噪音问题:FanControl终极配置指南 【免费下载链接】FanControl.Releases This is the release repository for Fan Control, a highly customizable fan controlling software for Windows. 项目地址: https://gitcode.com/GitHub_…...

终极鸣潮体验优化指南:3个简单技巧让你的游戏性能飞升!

终极鸣潮体验优化指南:3个简单技巧让你的游戏性能飞升! 【免费下载链接】WaveTools 🧰鸣潮工具箱 项目地址: https://gitcode.com/gh_mirrors/wa/WaveTools 还在为《鸣潮》的帧率锁定和画质模糊而烦恼吗?鸣潮工具箱&#x…...

RSSHub Radar:5分钟实现智能RSS订阅管理的浏览器扩展解决方案

RSSHub Radar:5分钟实现智能RSS订阅管理的浏览器扩展解决方案 【免费下载链接】RSSHub-Radar 🧡 Browser extension that simplifies finding and subscribing RSS and RSSHub 项目地址: https://gitcode.com/gh_mirrors/rs/RSSHub-Radar 在信息爆…...

如何实现设计到动画的无缝转换:AEUX开源插件的完整指南

如何实现设计到动画的无缝转换:AEUX开源插件的完整指南 【免费下载链接】AEUX Editable After Effects layers from Sketch artboards 项目地址: https://gitcode.com/gh_mirrors/ae/AEUX 在当今数字设计领域,从静态设计到动态动画的转换一直是设…...

掌握OR-Tools:5个步骤从零开始构建运筹优化解决方案

掌握OR-Tools:5个步骤从零开始构建运筹优化解决方案 【免费下载链接】or-tools Googles Operations Research tools: 项目地址: https://gitcode.com/gh_mirrors/or/or-tools OR-Tools优化工具是Google开源的运筹优化软件套件,专门解决复杂的组合…...

SGM算法调参避坑指南:如何根据你的图像设定P1、P2惩罚值(附Middlebury数据集实测)

SGM算法调参实战:从惩罚参数原理到Middlebury数据集优化策略 在双目立体视觉领域,半全局匹配(SGM)算法因其在精度与效率间的出色平衡,成为工业界和学术界的热门选择。但真正让工程师们夜不能寐的,往往是那些看似简单却暗藏玄机的调…...

从物联网小设备到工业网关:RT-Thread、FreeRTOS、uC/OS-II选型实战指南(附对比表格)

从物联网小设备到工业网关:RT-Thread、FreeRTOS、uC/OS-II选型实战指南 在智能农业监测系统的开发过程中,我们遇到了一个典型困境:如何为不同层级的设备选择合适的实时操作系统?从田间部署的微型土壤传感器到负责数据汇总的4G边缘…...

M9A智能助手如何为《重返未来:1999》玩家每周节省10小时?

M9A智能助手如何为《重返未来:1999》玩家每周节省10小时? 【免费下载链接】M9A 重返未来:1999 小助手 | Assistant For Reverse: 1999 项目地址: https://gitcode.com/gh_mirrors/m9/M9A 每天在《重返未来:1999》中重复点击…...

如何快速实现本地千万级图片秒级搜索:面向新手的完整指南

如何快速实现本地千万级图片秒级搜索:面向新手的完整指南 【免费下载链接】ImageSearch 基于.NET10的本地硬盘千万级图库以图搜图案例Demo和图片exif信息移除小工具分享 项目地址: https://gitcode.com/gh_mirrors/im/ImageSearch 你是否曾在海量图片库中迷失…...

英雄联盟LCU工具箱:League Akari 全面使用指南与实战技巧

英雄联盟LCU工具箱:League Akari 全面使用指南与实战技巧 【免费下载链接】League-Toolkit An all-in-one toolkit for LeagueClient. Gathering power 🚀. 项目地址: https://gitcode.com/gh_mirrors/le/League-Toolkit League Akari是一款基于英…...

如何让Obsidian笔记库拥有AI大脑:obsidian-copilot完全指南

如何让Obsidian笔记库拥有AI大脑:obsidian-copilot完全指南 【免费下载链接】obsidian-copilot THE Copilot in Obsidian 项目地址: https://gitcode.com/gh_mirrors/ob/obsidian-copilot 你是否曾在海量笔记中迷失方向?当需要从数百个文档中提取…...