当前位置: 首页 > article >正文

Transformer架构实战:从零开始手把手实现一个简易版(Python代码示例)

Transformer架构实战从零开始手把手实现一个简易版Python代码示例在人工智能领域Transformer架构已经彻底改变了自然语言处理的游戏规则。不同于传统的循环神经网络RNNTransformer通过自注意力机制实现了并行化处理大幅提升了模型训练和推理的效率。本文将带你从零开始用Python实现一个简化版的Transformer模型深入理解其核心组件的工作原理。1. 环境准备与基础概念在开始编码之前我们需要确保开发环境配置正确。推荐使用Python 3.8和PyTorch 1.10这些版本提供了良好的兼容性和性能优化。pip install torch numpy matplotlibTransformer的核心是自注意力机制它允许模型在处理序列数据时动态地为不同位置的元素分配不同的权重。这种机制模拟了人类阅读时的注意力分配方式——我们会更关注句子中重要的词语而忽略不太相关的部分。自注意力机制与传统RNN的关键区别RNN必须按顺序处理序列难以并行化Transformer可以同时处理整个序列充分利用现代GPU的并行计算能力2. 实现位置编码由于Transformer没有循环结构它需要一种特殊的方式来保留序列中词语的位置信息。这就是位置编码Positional Encoding的作用。import torch import math def positional_encoding(max_len, d_model): position torch.arange(max_len).unsqueeze(1) div_term torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe torch.zeros(max_len, d_model) pe[:, 0::2] torch.sin(position * div_term) pe[:, 1::2] torch.cos(position * div_term) return pe # 示例生成长度为10维度为512的位置编码 pe positional_encoding(10, 512) print(pe.shape) # 输出: torch.Size([10, 512])位置编码的关键特性每个位置有唯一的编码编码值在-1到1之间编码可以扩展到任意长度的序列提示位置编码不需要训练它是固定的数学函数生成的。这种设计允许模型处理比训练时更长的序列。3. 实现自注意力机制自注意力是Transformer最核心的组件它由三个主要部分组成查询Query、键Key和值Value。import torch.nn as nn class SelfAttention(nn.Module): def __init__(self, embed_size, heads): super(SelfAttention, self).__init__() self.embed_size embed_size self.heads heads self.head_dim embed_size // heads assert (self.head_dim * heads embed_size), Embed size needs to be divisible by heads self.values nn.Linear(self.head_dim, self.head_dim, biasFalse) self.keys nn.Linear(self.head_dim, self.head_dim, biasFalse) self.queries nn.Linear(self.head_dim, self.head_dim, biasFalse) self.fc_out nn.Linear(heads * self.head_dim, embed_size) def forward(self, values, keys, query, mask): N query.shape[0] value_len, key_len, query_len values.shape[1], keys.shape[1], query.shape[1] # 分割嵌入维度到多个头 values values.reshape(N, value_len, self.heads, self.head_dim) keys keys.reshape(N, key_len, self.heads, self.head_dim) queries query.reshape(N, query_len, self.heads, self.head_dim) values self.values(values) keys self.keys(keys) queries self.queries(queries) # 计算注意力分数 energy torch.einsum(nqhd,nkhd-nhqk, [queries, keys]) if mask is not None: energy energy.masked_fill(mask 0, float(-1e20)) attention torch.softmax(energy / (self.embed_size ** (1/2)), dim3) # 应用注意力权重到值上 out torch.einsum(nhql,nlhd-nqhd, [attention, values]).reshape( N, query_len, self.heads * self.head_dim ) out self.fc_out(out) return out自注意力计算过程详解线性变换将输入分别转换为Q、K、V分数计算计算Q和K的点积然后缩放Softmax归一化得到注意力权重加权求和用注意力权重对V进行加权4. 构建Transformer块现在我们可以将自注意力机制与其他组件组合起来构建完整的Transformer块。class TransformerBlock(nn.Module): def __init__(self, embed_size, heads, dropout, forward_expansion): super(TransformerBlock, self).__init__() self.attention SelfAttention(embed_size, heads) self.norm1 nn.LayerNorm(embed_size) self.norm2 nn.LayerNorm(embed_size) self.feed_forward nn.Sequential( nn.Linear(embed_size, forward_expansion * embed_size), nn.ReLU(), nn.Linear(forward_expansion * embed_size, embed_size) ) self.dropout nn.Dropout(dropout) def forward(self, value, key, query, mask): attention self.attention(value, key, query, mask) # 残差连接和层归一化 x self.dropout(self.norm1(attention query)) forward self.feed_forward(x) out self.dropout(self.norm2(forward x)) return outTransformer块的关键组件多头自注意力捕获序列中不同位置的关系前馈网络增加模型的非线性表达能力层归一化稳定训练过程残差连接缓解梯度消失问题5. 构建完整Transformer模型现在我们可以将所有组件组合起来构建完整的Transformer模型架构。class Transformer(nn.Module): def __init__( self, src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, embed_size512, num_layers6, forward_expansion4, heads8, dropout0, devicecpu, max_length100 ): super(Transformer, self).__init__() self.encoder_embedding nn.Embedding(src_vocab_size, embed_size) self.decoder_embedding nn.Embedding(trg_vocab_size, embed_size) self.positional_encoding positional_encoding(max_length, embed_size) self.encoder_layers nn.ModuleList( [ TransformerBlock( embed_size, heads, dropoutdropout, forward_expansionforward_expansion ) for _ in range(num_layers) ] ) self.decoder_layers nn.ModuleList( [ TransformerBlock( embed_size, heads, dropoutdropout, forward_expansionforward_expansion ) for _ in range(num_layers) ] ) self.fc_out nn.Linear(embed_size, trg_vocab_size) self.dropout nn.Dropout(dropout) self.src_pad_idx src_pad_idx self.trg_pad_idx trg_pad_idx self.device device def make_src_mask(self, src): src_mask (src ! self.src_pad_idx).unsqueeze(1).unsqueeze(2) return src_mask.to(self.device) def make_trg_mask(self, trg): N, trg_len trg.shape trg_mask torch.tril(torch.ones((trg_len, trg_len))).expand( N, 1, trg_len, trg_len ) return trg_mask.to(self.device) def forward(self, src, trg): src_mask self.make_src_mask(src) trg_mask self.make_trg_mask(trg) src_embedded self.dropout( (self.encoder_embedding(src) self.positional_encoding[:src.shape[1], :]) ) trg_embedded self.dropout( (self.decoder_embedding(trg) self.positional_encoding[:trg.shape[1], :]) ) enc_out src_embedded for layer in self.encoder_layers: enc_out layer(enc_out, enc_out, enc_out, src_mask) dec_out trg_embedded for layer in self.decoder_layers: dec_out layer(enc_out, enc_out, dec_out, trg_mask) out self.fc_out(dec_out) return out6. 模型训练与评估有了完整的Transformer模型后我们需要设置训练流程和评估指标。# 示例训练代码 device torch.device(cuda if torch.cuda.is_available() else cpu) # 假设我们有一些示例数据 src_vocab_size 5000 trg_vocab_size 5000 src_pad_idx 0 trg_pad_idx 0 model Transformer( src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, devicedevice ).to(device) optimizer torch.optim.Adam(model.parameters(), lr3e-4) criterion nn.CrossEntropyLoss(ignore_indextrg_pad_idx) def train(model, iterator, optimizer, criterion, clip): model.train() epoch_loss 0 for i, batch in enumerate(iterator): src batch.src.to(device) trg batch.trg.to(device) optimizer.zero_grad() output model(src, trg[:, :-1]) output_dim output.shape[-1] output output.contiguous().view(-1, output_dim) trg trg[:, 1:].contiguous().view(-1) loss criterion(output, trg) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), clip) optimizer.step() epoch_loss loss.item() return epoch_loss / len(iterator)训练Transformer模型时需要注意的几个关键点学习率调度使用warmup策略逐步提高学习率梯度裁剪防止梯度爆炸批处理充分利用GPU并行能力正则化适当使用dropout防止过拟合注意在实际应用中你可能需要使用更大的数据集和更长时间的训练才能获得良好的性能。这个简化版主要用于教学目的帮助你理解Transformer的核心原理。7. 模型优化技巧要让Transformer模型发挥最佳性能可以考虑以下优化技巧学习率调度class CustomSchedule(torch.optim.lr_scheduler._LRScheduler): def __init__(self, optimizer, d_model, warmup_steps4000): self.d_model d_model self.warmup_steps warmup_steps super().__init__(optimizer) def get_lr(self): step self._step_count 1 arg1 step ** (-0.5) arg2 step * (self.warmup_steps ** (-1.5)) return [base_lr * (self.d_model ** (-0.5) * min(step ** (-0.5), step * self.warmup_steps ** (-1.5))) for base_lr in self.base_lrs]标签平滑减轻模型对预测结果的过度自信criterion nn.CrossEntropyLoss( ignore_indextrg_pad_idx, label_smoothing0.1 )混合精度训练减少内存占用并加速训练scaler torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): output model(src, trg[:, :-1]) loss criterion(output, trg[:, 1:]) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()8. 实际应用中的注意事项在实际项目中应用Transformer模型时有几个关键因素需要考虑数据预处理文本清洗和标准化适当的tokenization策略子词分割如BPE处理罕见词模型大小选择小数据集4-6层512维度8个头中等数据集6-12层768维度12个头大数据集12-24层1024维度16个头硬件考虑GPU内存限制批处理大小多GPU训练策略混合精度训练推理优化模型量化减小部署体积缓存注意力计算结果束搜索(beam search)参数调优# 示例推理代码 def translate_sentence(sentence, model, device, max_length50): model.eval() # 简单的tokenization tokens sentence.lower().split() tokens [sos] tokens [eos] # 转换为索引 src_indexes [src_field.vocab.stoi[token] for token in tokens] src_tensor torch.LongTensor(src_indexes).unsqueeze(0).to(device) # 创建目标序列 trg_indexes [trg_field.vocab.stoi[sos]] for i in range(max_length): trg_tensor torch.LongTensor(trg_indexes).unsqueeze(0).to(device) with torch.no_grad(): output model(src_tensor, trg_tensor) pred_token output.argmax(2)[:,-1].item() trg_indexes.append(pred_token) if pred_token trg_field.vocab.stoi[eos]: break trg_tokens [trg_field.vocab.itos[i] for i in trg_indexes] return trg_tokens[1:]

相关文章:

Transformer架构实战:从零开始手把手实现一个简易版(Python代码示例)

Transformer架构实战:从零开始手把手实现一个简易版(Python代码示例) 在人工智能领域,Transformer架构已经彻底改变了自然语言处理的游戏规则。不同于传统的循环神经网络(RNN),Transformer通过自…...

Artifactory-oos私有Maven仓库:从零搭建到企业级组件托管实战

1. 为什么企业需要私有Maven仓库 记得去年我们团队接手一个大型金融项目时,遇到了一个典型问题:十几个模块都在重复使用相同的支付SDK,每次版本更新都要手动替换所有项目的jar包。更糟的是,某个同事不小心用了旧版本导致线上事故。…...

EC20模块实战:quectel-CM启动流程全解析(附常见问题排查)

EC20模块深度实战:quectel-CM启动全流程与高阶问题排查指南 在物联网设备开发中,EC20模块凭借其稳定的4G通信能力和丰富的功能接口,已成为工业级应用的常青树。而quectel-CM作为其核心连接管理工具,启动过程中的每个环节都直接影响…...

Unity WebGL中文输入难题破解:InputField全屏输入与跨平台适配方案

1. Unity WebGL中文输入难题解析 第一次用Unity开发WebGL项目时,我就被InputField的中文输入问题坑惨了。明明在编辑器里测试好好的,打包成WebGL后死活打不出中文,只能输入英文和数字。后来才发现这是Unity WebGL平台的"祖传问题"…...

C/C++中的u8、u16、u32数据类型实战指南:嵌入式开发中的高效应用

1. 嵌入式开发中的数据类型选择困境 第一次接触STM32开发时,我被各种u8、u16、u32数据类型搞得晕头转向。记得当时要处理一个温度传感器的数据,随手用了int类型,结果发现内存占用比预期大了整整一倍。这种经历让我深刻认识到,在嵌…...

【GitHub项目推荐--SimpleKernel:面向 AI 辅助学习的现代化操作系统内核】⭐⭐⭐

项目简介 SimpleKernel 是由 Simple-XX 团队维护的一个开源操作系统内核项目。与传统教学内核不同,它采用 Interface-Driven(接口驱动)​ 的设计理念,旨在利用 AI 辅助进行操作系统内核的学习与开发。项目采用 C23 编写&#xff…...

基于Pixel-to-Space的视频空间反演技术在智慧军营中的应用研究

《基于Pixel-to-Space的视频空间反演技术在智慧军营中的应用研究》副标题:面向三维感知与认知决策的空间计算体系构建发布单位:镜像视界(浙江)科技有限公司一、研究背景与问题提出随着智慧军营与智能化作战体系建设的不断推进&…...

新一代智慧军营空间智能底座:视频反演驱动的全域感知与作战中枢系统

《新一代智慧军营空间智能底座:视频反演驱动的全域感知与作战中枢系统》副标题:基于 Pixel-to-Space 的空间认知引擎与战术智能基础设施发布单位:镜像视界(浙江)科技有限公司一、执行摘要随着智能化作战体系与数字化军…...

空间重构驱动的智慧军营:三维感知 × 行为认知 × 智能指挥体系

《空间重构驱动的智慧军营:三维感知 行为认知 智能指挥体系》副标题:基于 Pixel-to-Space 的军营空间认知与战术决策引擎发布单位:镜像视界(浙江)科技有限公司一、执行摘要在智能化作战体系持续演进的背景下&#xf…...

使用Python实现Blender与虚幻引擎PSK/PSA格式自动化处理方案

使用Python实现Blender与虚幻引擎PSK/PSA格式自动化处理方案 【免费下载链接】io_scene_psk_psa A Blender plugin for importing and exporting Unreal PSK and PSA files 项目地址: https://gitcode.com/gh_mirrors/io/io_scene_psk_psa 在现代游戏开发工作流中&#…...

从视频到空间:面向智慧军营的三维作战感知与认知决策平台

《从视频到空间:面向智慧军营的三维作战感知与认知决策平台》副标题:基于 Pixel-to-Space 的空间认知引擎与战术智能体系发布单位:镜像视界(浙江)科技有限公司一、执行摘要随着信息化战争向智能化战争演进,…...

从‘看WP’到‘写WP’:我的CTF逆向入门踩坑实录与BUUCTF前16题保姆级复盘

从‘看WP’到‘写WP’:我的CTF逆向入门踩坑实录与BUUCTF前16题保姆级复盘 第一次接触CTF逆向时,面对满屏的汇编代码和陌生的工具界面,我完全不知所措。和大多数新手一样,我开始疯狂搜索别人的解题报告(Writeup&#xf…...

Fiverr实验室突破:AI代理开发实现食谱式简化流程

这项由Fiverr实验室领导的研究发表于2026年的arXiv平台,论文编号为arXiv:2603.08806v1,研究团队开发了一种全新的AI代理开发方法。有兴趣深入了解的读者可以通过该编号查询完整论文。现在的AI助手开发就像在没有食谱的情况下做一道复杂菜肴——你知道想要…...

半导体材料中的晶体结构解析:从NaCl到金刚石,工程师必备知识

半导体材料中的晶体结构解析:从NaCl到金刚石,工程师必备知识 在半导体工业的精密制造中,晶体结构如同建筑的地基,决定了材料的电学、热学和机械性能。当我们拆解一枚芯片时,从硅衬底到氮化镓功率器件,背后都…...

ComfyUI NSFW视频模型下载与部署实战指南:从环境搭建到避坑技巧

最近在尝试部署一些视频生成模型,发现ComfyUI的生态确实很丰富,但NSFW(Not Safe For Work)相关的视频模型在下载和部署过程中会遇到不少坑。经过一番折腾,总算整理出了一套比较顺畅的流程。这篇笔记就记录一下从环境搭…...

RK3588直播机实战:如何用一台设备搞定多机位4K直播(附配置清单)

RK3588直播机实战:如何用一台设备搞定多机位4K直播(附配置清单) 在当今内容创作爆发的时代,专业级直播设备的需求与日俱增,但传统多机位直播系统的高昂成本和复杂操作让许多中小团队望而却步。RK3588直播机的出现&…...

Qt实战:QTableView合并单元格的3种实用场景与完整代码示例

Qt实战:QTableView合并单元格的3种实用场景与完整代码示例 在Qt开发中,表格数据展示是常见的需求场景。当我们需要展示具有层级关系或分组特性的数据时,合并单元格功能就显得尤为重要。不同于简单的表格布局,合并单元格能够有效提…...

计算机毕业设计:Python房源数据采集分析与智能估价系统 Flask框架 scikit-learn机器学习 可视化 爬虫 SVR算法 房子 房屋 大数据(建议收藏)✅

博主介绍:✌全网粉丝10W,前互联网大厂软件研发、集结硕博英豪成立软件开发工作室,专注于计算机相关专业项目实战6年之久,累计开发项目作品上万套。凭借丰富的经验与专业实力,已帮助成千上万的学生顺利毕业,…...

Neo4j图算法特征工程全攻略:如何为你的GraphSAGE模型注入“专家经验”(以反欺诈为例)

Neo4j图算法特征工程全攻略:如何为你的GraphSAGE模型注入“专家经验”(以反欺诈为例) 在金融风控领域,欺诈用户往往像变色龙一样隐藏在正常用户群体中。传统的结构化数据特征常常难以捕捉这些"伪装者"的蛛丝马迹&#x…...

从Presto到Trino:我们迁移集群踩过的坑与性能对比实录(附436版本调优参数)

从Presto到Trino:迁移实战与性能调优全指南 当我们的数据团队第一次面对从Presto迁移到Trino的决策时,整个团队都充满了疑虑和期待。作为曾经在Presto上运行了数百个关键业务查询的平台,迁移不仅意味着技术栈的变更,更关系到整个数…...

鸣潮高帧率体验完整解决方案:从技术原理到实战优化

鸣潮高帧率体验完整解决方案:从技术原理到实战优化 【免费下载链接】WaveTools 🧰鸣潮工具箱 项目地址: https://gitcode.com/gh_mirrors/wa/WaveTools WaveTools鸣潮工具箱作为开源项目中的佼佼者,为玩家提供了突破游戏帧率限制的完整…...

3步突破:解锁VMware macOS虚拟化的开源方案

3步突破:解锁VMware macOS虚拟化的开源方案 【免费下载链接】unlocker 项目地址: https://gitcode.com/gh_mirrors/unloc/unlocker 当你尝试在VMware中创建macOS虚拟机时,是否遇到过"该操作系统不受支持"的提示?这个常见问…...

Qwen3-4B-Thinking-GGUF开源大模型部署教程:Apache-2.0许可下的企业可用方案

Qwen3-4B-Thinking-GGUF开源大模型部署教程:Apache-2.0许可下的企业可用方案 想找一个开箱即用、性能不错,最关键的是能放心用在商业项目里的开源大模型?今天要聊的 Qwen3-4B-Thinking-GGUF 模型,可能就是你的菜。 它基于通义千…...

DevUI实战指南:10分钟构建企业级Vue后台表单系统

1. 为什么选择DevUI构建企业级表单系统 第一次接触DevUI时,我正为一个电商后台系统焦头烂额。传统UI库的表单在复杂业务场景下就像拼凑的积木,联动校验和异步提交总出问题。直到用DevUI重构了用户管理模块,才发现原来表单开发可以这么高效。 …...

Unity Shader描边别再只用背面膨胀了!这几种方案优缺点和适用场景一次讲清

Unity Shader描边技术深度解析:从基础到高阶实战方案 在游戏开发中,描边效果是提升视觉表现力的重要手段之一。无论是角色高亮、场景交互提示还是特效增强,恰到好处的描边都能显著提升游戏品质。然而,许多开发者往往止步于简单的背…...

从泄漏电流到智能预警:避雷器监测数据的5种高级分析方法(Python示例)

从泄漏电流到智能预警:避雷器监测数据的5种高级分析方法(Python示例) 避雷器作为电力系统的"隐形守护者",其健康状态直接影响电网安全。传统的人工巡检和阈值告警已无法满足智能电网的需求——我们需要的不是简单的数据…...

ESP32固件烧录全攻略:从GPIO0拉低到串口调试的5个关键步骤

ESP32固件烧录实战手册:从硬件准备到成功运行的完整指南 第一次接触ESP32开发板时,那块小小的蓝色电路板让我既兴奋又忐忑。作为物联网项目的核心控制器,ESP32的强大功能毋庸置疑,但如何将编写好的程序成功烧录到芯片中&#xff0…...

移动端适配实战:从rem到vw的平滑迁移指南(附完整代码示例)

移动端适配实战:从rem到vw的平滑迁移指南(附完整代码示例) 在移动互联网时代,多终端适配已成为前端开发的基本功。随着CSS3视口单位(vw/vh)的广泛支持,越来越多的团队开始从传统的rem方案转向更现代的vw方案。本文将深…...

Guacamole前端API详解:从零实现Vue远程桌面控制台

Guacamole前端API详解:从零实现Vue远程桌面控制台 远程桌面技术在现代企业应用中扮演着重要角色,而Guacamole作为一款开源的远程桌面网关,其前端API的实现方式却鲜有详细讨论。本文将深入剖析guacamole-common.js中的核心API,并结…...

快速上手PyTorch 2.5:无需IT支持,自己搞定GPU环境

快速上手PyTorch 2.5:无需IT支持,自己搞定GPU环境 1. 为什么选择PyTorch 2.5 GPU镜像? 作为一名AI开发者或研究人员,最令人沮丧的莫过于花费数小时甚至数天配置开发环境。特别是当需要GPU加速时,CUDA驱动安装、版本兼…...