当前位置: 首页 > article >正文

从零实现 Llama 3:架构拆解与实现细节

本文参考以下英文教程撰写https://pub.towardsai.net/build-your-own-llama-3-architecture-from-scratch-using-pytorch-2ce1ecaa901c第一次看到有人把 Llama 3 从零实现一遍我就知道这件事值得认真做一次。因为只有真正写出来才能体会到每一个设计选择背后的逻辑——为什么 Norm 放在前面而不是后面为什么 KV Cache 只缓存 K 和 V 而不包括 Q为什么 RoPE 要转成复数域再做乘法……这些问题看论文能得到答案但只有自己写代码才能真正把它们变成直觉。这篇笔记的目标是把 Llama 3 的每一个模块从设计动机到公式推导到代码实现都讲清楚。所有技术细节来自 Meta 官方论文所有代码都可以独立运行。我们用 Tiny Shakespeare 数据集来做演示训练因为它足够小能快速看到结果同时又足够有趣让你感受到语言模型在学什么。先看全局Llama 3 本质上是一个标准的 decoder-only transformer但 Meta 在几个关键位置做了精准的替换。架构本身不复杂复杂的是每个替换背后的权衡。整个前向传播可以分成三段输入块 → 解码器块×N层→ 输出块。一、Llama 3 整体架构三个关键杠杆在深入每个模块之前先把 Llama 3 的全貌说清楚。Llama 3 是一个标准的稠密 Transformer 解码器dense Transformer decoder架构论文明确说没有采用混合专家模型MoE主要原因是为了最大化训练稳定性、降低复杂度。论文指出了驱动 Llama 3 性能的三个核心杠杆数据Data预训练语料从 Llama 2 的 1.8T tokens 提升到 15T tokens增幅超过 8 倍。数据质量也大幅提升引入了多轮清洗和领域分类策略。最终的数据配比是约 50% 通用知识、25% 数学与推理、17% 代码、8% 多语言内容。规模Scale最大模型 405B 参数用 3.8×10²⁵ FLOPs 训练是 Llama 2 最大版本的约 50 倍算力。论文按照 Chinchilla 缩放定律推算出计算最优点在 402B 参数和 16.55T tokens最终选择了 405B。复杂度管理Managing ComplexityPost-training 阶段采用 SFT 拒绝采样RS DPO 的组合而非更复杂的 RL 算法理由是复杂算法更难稳定扩展。整个模型在架构层面继承了 Llama 2 的基本骨架但有四处关键修改GQA 的 KV head 数量固定为 8、词表从 32K 扩充到 128K、RoPE base frequency 从 10000 提升到 500000、以及引入跨文档注意力 mask。下面逐一讲。二、输入块Input Block输入块包含三个组件文本/提示词、分词器、嵌入层。1.1 分词器从字符到 token 的映射Llama 3 在生产环境中使用 TikToken 作为分词器这是一个子词subword级分词器词表大小128,000由 100,000 个来自 tiktoken 的基础 token 加上 28,000 个额外的非英语语言 token 组成。相比 Llama 2 的 SentencePiece 分词器词表 32K这次扩充带来的直接收益是压缩率提升——同样一段英文文本Llama 3 平均每个 token 能表示 3.94 个字符而 Llama 2 只有 3.17 个字符。这意味着在相同计算量下Llama 3 能读到更多文本。为什么词表变大能提升压缩率本质是因为更大的词表可以把更多常见词组和词缀直接存成一个 token而不是拆成多个子词。比如 generating 如果本身在词表里就是 1 个 token如果不在可能被拆成 generat 和 ing 两个 token。词表越大整体需要的 token 数越少序列越短注意力计算成本越低。特殊 token 方面Llama 3 定义了|begin_of_text|、|end_of_text|、|eot_id|turn 结束、|start_header_id|和|end_header_id|等这些在对话场景下有重要的结构化作用。在我们的从零实现中用字符级分词器来代替 TikToken目的是让整个 encode/decode 流程完全透明可控with open(tiny_shakespeare.txt, r) as f: data f.read() vocab sorted(list(set(data))) vocab.extend([|begin_of_text|, |end_of_text|, |pad_id|]) vocab_size len(vocab) itos {i: ch for i, ch in enumerate(vocab)} stoi {ch: i for i, ch in enumerate(vocab)} def encode(s): return [stoi[ch] for ch in s] def decode(l): return .join(itos[i] for i in l) token_bos torch.tensor([stoi[|begin_of_text|]], dtypetorch.int, devicedevice) token_eos torch.tensor([stoi[|end_of_text|]], dtypetorch.int, devicedevice) token_pad torch.tensor([stoi[|pad_id|]], dtypetorch.int, devicedevice)1.2 模型超参数在进入解码器之前先把本次实现用到的所有参数集中定义好。注意这里为了让训练快速出结果把 dim 调小到 512实际 Llama 3 8B 的 dim 是 4096dataclass class ModelArgs: dim: int 512 n_layers: int 8 n_heads: int 8 n_kv_heads: int 4 # 对应论文里 GQA 的 KV heads 8这里按比例缩小 vocab_size: int len(vocab) multiple_of: int 256 ffn_dim_multiplier: Optional[float] None norm_eps: float 1e-5 rope_theta: float 500000.0 # 论文把这个值从 10000 提升到了 500000 max_batch_size: int 10 max_seq_len: int 256 epochs: int 2500 log_interval: int 10 device: str cuda if torch.cuda.is_available() else cpu这里有一个细节值得单独说rope_theta 500000.0。这是 Llama 3 相比 Llama 2 的四大架构改动之一。原论文指出将 RoPE 的 base frequency 从默认的 10000 提升到 500000能让模型有效处理更长的上下文——论文引用的研究表明这个值对 32768 长度的上下文有效而在后续的长上下文继续预训练阶段上下文长度进一步扩展到了 128K tokens。三、解码器块Decoder Block解码器块是 Llama 3 的核心。每个解码器块包含六个子组件RMSNorm、RoPE、KV Cache、GQA、FeedForward Network以及把它们组装在一起的 TransformerBlock。我们逐一深入。3.1 RMSNorm比 LayerNorm 更高效的归一化为什么需要归一化embedding 向量在各个维度上的数值范围差异很大直接送入后续计算会导致梯度爆炸或消失训练不稳定。归一化把这些值拉到一个合适的范围让梯度的量级更一致训练更稳。为什么用 RMSNorm 而不是 LayerNorm这是 Llama 系列从第一代就继承下来的选择。LayerNorm 需要计算均值mean和方差variance而 RMSNorm 完全省掉了均值的计算只保留 RMSRoot Mean Square均方根这一步这意味着第一少了均值计算计算开销降低第二没有偏移参数 β参数更少第三论文作者的实验表明性能相当甚至更好。直觉上归一化的关键作用是控制量级scale而均值中心化对这个目标的贡献有限省掉它代价不大。同时要注意Llama 3 采用Pre-Norm 结构即在注意力和前馈网络之前做归一化而不是之后。Pre-Norm 相比 Post-Norm 训练更稳定这一点已经被大量工作证明。class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float 1e-6): super().__init__() self.eps eps self.weight nn.Parameter(torch.ones(dim).to(device)) def _norm(self, x): # x.pow(2).mean(dim-1, keepdimTrue) 是对最后一个维度embedding dim求均方 # rsqrt 1 / sqrt整体就是 x / RMS(x) return x * torch.rsqrt(x.pow(2).mean(dim-1, keepdimTrue) self.eps).to(device) def forward(self, x): # Shape: x[bs, seq, dim] - output[bs, seq, dim] output self._norm(x.float()).type_as(x) return output * self.weight3.2 旋转位置编码RoPE用旋转矩阵编码绝对位置和相对位置问题Transformer 的自注意力机制本质上是置换不变的permutation invariant——把输入序列的顺序打乱注意力得分的模式不会变。但语言显然是有顺序的我爱你和你爱我意思截然不同。所以必须想办法把位置信息注入 embedding。Llama 1/2/3 都用 RoPERotary Positional Encoding而不是原始 Transformer 里的正弦绝对位置编码也不是可学习的绝对位置编码。RoPE 的核心思路是用旋转矩阵对 Q 和 K 的 embedding 进行旋转使得旋转角度正比于 token 的绝对位置。这样做的妙处是注意力分数 Q·K^T 在经过旋转之后只跟两个 token 的相对位置有关与它们的绝对位置无关——因为旋转是线性变换两个旋转矩阵相乘的结果只取决于它们的旋转角度之差即 m-nm 和 n 分别是两个 token 的绝对位置。这就同时实现了绝对位置编码和相对位置感知。RoPE 的数学实现对于位置 m 的 token每对相邻维度 (2i, 2i1) 按角度 m × θᵢ 旋转其中这里 base 就是rope_theta。Llama 3 把 base 从 10000 改成了 500000使得高频分量小 θ变化更缓慢能更好地处理长序列中遥远位置之间的相对关系。旋转操作在实数域是矩阵乘法但在复数域只是点乘所以实现上会先把 embedding 转成复数乘以旋转因子再转回实数def precompute_freqs_cis(dim: int, seq_len: int, theta: float 500000.0): device ModelArgs.device # 计算每对维度的 theta 值θᵢ 1 / (theta^(2i/dim)) freqs 1.0 / (theta ** (torch.arange(0, dim, 2, devicedevice)[:(dim // 2)].float() / dim)) # 计算序列中每个位置 m 的值 t torch.arange(seq_len, dtypetorch.float32, devicedevice) # outer product 得到每个位置每个维度对的旋转角度m × θᵢ freqs torch.outer(t, freqs).to(device) # 转成极坐标形式模1角度freqs即 e^(i·m·θ) 的复数表示 freqs_cis torch.polar(torch.ones_like(freqs).to(device), freqs).to(device) return freqs_cis def reshape_for_broadcast(freqs_cis, x): ndim x.ndim assert freqs_cis.shape (x.shape[1], x.shape[-1]) shape [d if i 1 or i ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) - Tuple[torch.Tensor, torch.Tensor]: device ModelArgs.device # 把最后一维两两配对视作复数[bsz, seq_len, n_heads, head_dim/2] xq_ torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)).to(device) xk_ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)).to(device) freqs_cis reshape_for_broadcast(freqs_cis, xq_) # 复数乘法 旋转然后转回实数并展平最后两维 xq_out torch.view_as_real(xq_ * freqs_cis).flatten(3).to(device) xk_out torch.view_as_real(xk_ * freqs_cis).flatten(3).to(device) return xq_out.type_as(xq), xk_out.type_as(xk)特别注意RoPE 只施加在 Q 和 K 上不施加在 V 上。这是因为 RoPE 的目的是让注意力分数Q·K^T感知相对位置关系而 V 是被注意力加权聚合的值不需要位置旋转。3.3 KV Cache推理时空间换时间的关键优化KV Cache 只在推理阶段启用训练时不需要。理解它的必要性需要先理解自回归生成的过程。在推理时模型每次只生成一个 token但每次生成都要做完整的注意力计算。假设当前已经生成了 t 个 token要生成第 t1 个没有 KV Cache 的情况对 t1 个 token 做完整的 QKV 计算每次都要对所有历史 token 重新计算 K 和 V。这些历史 token 的 K 和 V 在上一步已经算过了完全是重复计算。矩阵乘法的规模是 (t1) × (t1)随着序列变长计算量呈平方增长。有 KV Cache 的情况把每一步计算出来的 K 和 V 存下来下一步直接复用。当前步只需要用最新的一个 Q token与缓存里所有历史的 K、V 做注意力计算矩阵乘法变成 1 × (t1)大幅降低计算量。Q 不需要缓存的原因每一步我们只用当前位置的 Q 来查询它不会被未来的步骤复用。# 在 Attention.__init__ 里初始化 KV Cache self.cache_k torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), deviceargs.device) self.cache_v torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self3.4 分组查询注意力GQA在精度与效率之间找到最优平衡点Llama 3 把 GQA 从 Llama 2 仅在 70B 模型上使用扩展到了所有规模8B、70B、405B都使用且不管模型大小KV heads 数量统一固定为8。理解 GQA 需要先理解三种注意力机制的区别Multi-Head AttentionMHAQ、K、V 各有相同数量的 head比如 32 个。每个 Q head 都有自己对应的独立 K head 和 V head。KV Cache 大小正比于 head 数量。Multi-Query AttentionMQAQ 有多个 head但所有 Q head 共享同一组 K 和 V只有 1 个 KV head 对。KV Cache 大幅缩减但不同 Q head 之间的表示多样性受限可能损失模型质量。Grouped Query AttentionGQA介于两者之间——Q head 分成若干组同一组内的 Q head 共享一对 K/V head。Llama 3 8B 有 32 个 Q heads、8 个 KV heads分组数 32/8 4每 4 个 Q head 共享一对 KV head。这样做的好处是KV Cache 从 MHA 的 32× 降到 8×显存占用大幅减少但保留了 32 个独立的 Q heads注意力表达能力不受影响。实验数据表明这个设计在质量和效率之间取得了很好的平衡。class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args args self.dim args.dim self.n_heads args.n_heads self.n_kv_heads args.n_heads if args.n_kv_heads is None else args.n_kv_heads self.head_dim args.dim // args.n_heads # 每个 KV head 需要被几个 Q head 共享 self.n_rep args.n_heads // args.n_kv_heads # Q 的输出维度n_heads × head_dim # K/V 的输出维度n_kv_heads × head_dim更小 self.wq nn.Linear(self.dim, self.n_heads * self.head_dim, biasFalse, devicedevice) self.wk nn.Linear(self.dim, self.n_kv_heads * self.head_dim, biasFalse, devicedevice) self.wv nn.Linear(self.dim, self.n_kv_heads * self.head_dim, biasFalse, devicedevice) self.wo nn.Linear(self.n_heads * self.head_dim, self.dim, biasFalse, devicedevice) self.cache_k torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), deviceargs.device) self.cache_v torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), deviceargs.device) def forward(self, x: torch.Tensor, start_pos, inference): bsz, seq_len, _ x.shape mask None xq self.wq(x) xk self.wk(x) xv self.wv(x) # reshape 到 [bsz, seq_len, n_heads, head_dim] xq xq.view(bsz, seq_len, self.n_heads, self.head_dim) xk xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) xv xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) if inference: # 推理模式启用 KV Cacherope_theta 用论文值 500000 freqs_cis precompute_freqs_cis(dimself.head_dim, seq_lenself.args.max_seq_len * 2, thetaself.args.rope_theta) freqs_cis freqs_cis[start_pos: start_pos seq_len] xq, xk apply_rotary_emb(xq, xk, freqs_cis) self.cache_k self.cache_k.to(xq) self.cache_v self.cache_v.to(xq) self.cache_k[:bsz, start_pos:start_pos seq_len] xk self.cache_v[:bsz, start_pos:start_pos seq_len] xv keys self.cache_k[:bsz, :start_pos seq_len] values self.cache_v[:bsz, :start_pos seq_len] # 把 KV heads 扩展到和 Q heads 一样多以便做矩阵乘法 keys repeat_kv(keys, self.n_rep) values repeat_kv(values, self.n_rep) else: # 训练模式不用 KV Cache直接对整个序列做注意力 freqs_cis precompute_freqs_cis(dimself.head_dim, seq_lenself.args.max_seq_len, thetaself.args.rope_theta) xq, xk apply_rotary_emb(xq, xk, freqs_cis) keys repeat_kv(xk, self.n_rep) values repeat_kv(xv, self.n_rep) # 因果 mask上三角填 -inf防止当前 token 看到未来 token mask torch.full((seq_len, seq_len), float(-inf), deviceself.args.device) mask torch.triu(mask, diagonal1).to(self.args.device) # Transpose to [bsz, n_heads, seq_len, head_dim] xq xq.transpose(1, 2) keys keys.transpose(1, 2) values values.transpose(1, 2) # 注意力分数 Q·K^T / √d_k然后 softmax然后加权 V scores torch.matmul(xq, keys.transpose(2, 3)).to(self.args.device) / math.sqrt(self.head_dim) if mask is not None: scores scores mask scores F.softmax(scores.float(), dim-1).type_as(xq) output torch.matmul(scores, values).to(self.args.device) # 把所有 head 的输出拼回来[bsz, n_heads, seq_len, head_dim] - [bsz, seq_len, dim] output output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) return self.wo(output) def repeat_kv(x: torch.Tensor, n_rep: int) - torch.Tensor: 把 KV head 的数量扩展到和 Q head 一样多 bsz, seq_len, n_kv_heads, head_dim x.shape if n_rep 1: return x return ( x[:, :, :, None, :] .expand(bsz, seq_len, n_kv_heads, n_rep, head_dim) .reshape(bsz, seq_len, n_kv_heads * n_rep, head_dim) )这里有一个很容易忽视的细节因果 mask 在训练时是必须的因为训练时整个序列一次性并行处理模型必须被阻止看到未来的 token。而在推理时因为 KV Cache 的存在每次只处理一个新 token天然没有未来信息泄露的问题所以 mask 不需要了。3.5 前馈网络FeedForward with SwiGLU前馈网络在每个注意力块之后负责对每个 token 的表示做非线性变换让模型能学到更复杂的特征。Llama 3 使用SwiGLU激活函数而非原始 Transformer 里的 ReLU 或 GeLU。SwiGLU 全称是 Swish-Gated Linear Unit它的前馈计算公式是其中 SiLU(x) x · σ(x)σ 是 sigmoid⊙ 是逐元素乘法。和标准的两层 FFN 不同SwiGLU 用了三个线性变换矩阵W₁、W₂、W₃其中 W₁ 的输出经过 SiLU 激活后作为门对 W₃ 的输出进行过滤。为什么 SwiGLU 比 ReLU 好关键在于 ReLU 在负数区域输出全为 0hard gate而 SwiGLU 在负数区域有平滑的非零输出梯度更连续也保留了一定的负数信息。这种软门控机制让模型的表达能力更强同时训练更稳定。由于多了 W₃ 这个矩阵FFN 的参数量相比 ReLU 版本多了约 50%。为了保持总参数量不变Llama 3 对隐层维度做了调整论文里用的隐层维度计算公式是int(2 * hidden_dim / 3)并取 256 的倍数class FeedForward(nn.Module): def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]): super().__init__() self.dim dim # 按 Meta 的隐层维度计算公式先缩到 2/3再取 256 的整数倍 hidden_dim int(2 * hidden_dim / 3) if ffn_dim_multiplier is not None: hidden_dim int(ffn_dim_multiplier * hidden_dim) hidden_dim multiple_of * ((hidden_dim multiple_of - 1) // multiple_of) self.w1 nn.Linear(self.dim, hidden_dim, biasFalse, devicedevice) # gate 路径 self.w2 nn.Linear(hidden_dim, self.dim, biasFalse, devicedevice) # 输出投影 self.w3 nn.Linear(self.dim, hidden_dim, biasFalse, devicedevice) # value 路径 def forward(self, x): # SwiGLU: SiLU(W₁x) ⊙ W₃x然后过 W₂ return self.w2(F.silu(self.w1(x)) * self.w3(x))从论文给出的具体数字来看Llama 3 8B 的 FFN 隐层维度是 14,33670B 是 28,672405B 是 53,248。3.6 解码器块TransformerBlock把所有子模块组装起来一个完整的 TransformerBlock 按照以下顺序执行输入 x 先过 RMSNormPre-Norm再进入注意力层注意力输出和原始 x 做残差连接Residual Connection残差连接结果再过 RMSNorm进入 FFNFFN 输出再次做残差连接用公式写就是残差连接是 Transformer 稳定深度训练的关键。没有残差连接梯度在穿过几十层之后会彻底消失完全无法训练。残差连接提供了一条高速路让梯度可以直接从输出层流回输入层。class TransformerBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.attention_norm RMSNorm(dimargs.dim, epsargs.norm_eps) self.attention Attention(args) self.ff_norm RMSNorm(dimargs.dim, epsargs.norm_eps) self.feedforward FeedForward(args.dim, 4 * args.dim, args.multiple_of, args.ffn_dim_multiplier) def forward(self, x, start_pos, inference): # Pre-Norm Attention Residual h x self.attention(self.attention_norm(x), start_pos, inference) # Pre-Norm FFN Residual out h self.feedforward(self.ff_norm(h)) return outLlama 3 的三个规模对应不同的解码器层数8B 模型有 32 层70B 模型有 80 层405B 模型有 126 层。每一层的结构完全相同只是宽度dim不同。四、输出块Output Block与完整模型所有解码器块处理完之后最后的 hidden states 流入输出块先过一次 RMSNorm再过一个线性层把 embedding 维度映射到词表大小输出 logits。logits 的每个维度对应词表里的一个 tokensoftmax 之后就是模型预测下一个 token 的概率分布。训练时把 logits 和真实 target labels 传入交叉熵损失函数反向传播更新所有参数。推理时从 logits 对应的概率分布中采样得到下一个生成的 token。class Transformer(nn.Module): def __init__(self, params: ModelArgs): super().__init__() self.params params # 输入层token id - embedding vector self.tok_embeddings nn.Embedding(params.vocab_size, params.dim) # 解码器堆叠n_layers 个 TransformerBlock self.layers nn.ModuleList() for layer_id in range(params.n_layers): self.layers.append(TransformerBlock(argsparams)) # 输出层 self.norm RMSNorm(params.dim, epsparams.norm_eps) self.output nn.Linear(params.dim, params.vocab_size, biasFalse) def forward(self, x, start_pos0, targetsNone): # x: [bsz, seq_len] - h: [bsz, seq_len, dim] h self.tok_embeddings(x) inference targets is None for layer in self.layers: h layer(h, start_pos, inference) h self.norm(h) # h: [bsz, seq_len, dim] - logits: [bsz, seq_len, vocab_size] logits self.output(h).float() loss None if targets is not None: loss F.cross_entropy(logits.view(-1, self.params.vocab_size), targets.view(-1)) return logits, loss五、训练在训练流程上我们用 80% 的数据做训练10% 做验证10% 做测试。每次随机采样一个 batch输入是从|begin_of_text|开始的序列目标是把这个序列向右移动一位即每个位置预测下一个 token。dataset torch.tensor(encode(data), dtypetorch.int).to(ModelArgs.device) def get_dataset_batch(data, split, args: ModelArgs): seq_len args.max_seq_len batch_size args.max_batch_size device args.device train data[:int(0.8 * len(data))] val data[int(0.8 * len(data)): int(0.9 * len(data))] test data[int(0.9 * len(data)):] batch_data train if split val: batch_data val if split test: batch_data test ix torch.randint(0, len(batch_data) - seq_len - 3, (batch_size,)).to(device) # x|begin_of_text| 正文前 seq_len-1 个字符 x torch.stack([torch.cat([token_bos, batch_data[i:i seq_len - 1]]) for i in ix]).long().to(device) # y正文后 seq_len-1 个字符 |end_of_text|即 x 右移一位 y torch.stack([torch.cat([batch_data[i 1:i seq_len], token_eos]) for i in ix]).long().to(device) return x, y torch.no_grad() def evaluate_loss(model, args: ModelArgs): out {} model.eval() for split in [train, val]: losses [] for _ in range(10): xb, yb get_dataset_batch(dataset, split, args) _, loss model(xxb, targetsyb) losses.append(loss.item()) out[split] np.mean(losses) model.train() return out def train(model, optimizer, args: ModelArgs): epochs args.epochs log_interval args.log_interval device args.device losses [] start_time time.time() for epoch in range(epochs): optimizer.zero_grad() xs, ys get_dataset_batch(dataset, train, args) xs xs.to(device) ys ys.to(device) logits, loss model(xxs, targetsys) loss.backward() optimizer.step() if epoch % log_interval 0: batch_time time.time() - start_time x evaluate_loss(model, args) losses [x] print(fEpoch {epoch} | val loss {x[val]:.3f} | Time {batch_time:.3f}) start_time time.time() print(validation loss: , losses[-1][val]) return pd.DataFrame(losses).plot() model Transformer(ModelArgs).to(ModelArgs.device) optimizer torch.optim.Adam(model.parameters()) train(model, optimizer, ModelArgs)在 Google Colab 的免费 GPU 上2500 个 epoch 大约需要 10 分钟最终 validation loss 在 2.19 左右。这个数字并不算低原因是我们只用了 Tiny Shakespeare 这个小数据集且没有做任何超参数调优。真实的 Llama 3 在 15T tokens 上训练差距是数量级的。六、推理推理的核心是自回归生成autoregressive generation每次预测一个 token把这个 token 追加到输入序列再继续预测下一个直到生成了最大长度或遇到结束符。采样策略使用Top-pNucleusSampling按概率降序排列所有 token找到累积概率刚好超过 p 的最小集合只从这个集合里随机采样。这比直接取最大概率贪心解码生成的文本更有多样性比完全随机采样又更有质量保证。Temperature 参数控制分布的尖锐程度temperature 1 会让分布更集中更保守temperature 1 会让分布更平坦更随机。def generate(model, prompts: str, params: ModelArgs, max_gen_len: int 500, temperature: float 0.6, top_p: float 0.9): bsz 1 prompt_tokens token_bos.tolist() encode(prompts) assert len(prompt_tokens) params.max_seq_len total_len min(len(prompt_tokens) max_gen_len, params.max_seq_len) tokens torch.full((bsz, total_len), fill_valuetoken_pad.item(), dtypetorch.long, deviceparams.device) tokens[:, :len(prompt_tokens)] torch.tensor(prompt_tokens, dtypetorch.long, deviceparams.device) input_text_mask tokens ! token_pad.item() prev_pos 0 for cur_pos in range(1, total_len): with torch.no_grad(): logits, _ model(xtokens[:, prev_pos:cur_pos], start_posprev_pos) if temperature 0: probs torch.softmax(logits[:, -1] / temperature, dim-1) next_token sample_top_p(probs, top_p) else: next_token torch.argmax(logits[:, -1], dim-1) next_token next_token.reshape(-1) # 如果当前位置是 prompt 的一部分不要覆盖它 next_token torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) tokens[:, cur_pos] next_token prev_pos cur_pos if tokens[:, cur_pos] token_pad.item() and next_token token_eos.item(): break output_tokens, output_texts [], [] for i, toks in enumerate(tokens.tolist()): if token_eos.item() in toks: eos_idx toks.index(token_eos.item()) toks toks[:eos_idx] output_tokens.append(toks) output_texts.append(decode(toks)) return output_tokens, output_texts def sample_top_p(probs, p): # 按概率降序排列 probs_sort, prob_idx torch.sort(probs, dim-1, descendingTrue) probs_sum torch.cumsum(probs_sort, dim-1) # 找到累积概率超过 p 的截断点把截断点之后的概率置 0 mask probs_sum - probs_sort p probs_sort[mask] 0.0 # 重归一化后采样 probs_sort.div_(probs_sort.sum(dim-1, keepdimTrue)) next_token torch.multinomial(probs_sort, num_samples1) next_token torch.gather(prob_idx, -1, next_token) return next_token prompts Consider you what services he has done output_tokens, output_texts generate(model, prompts, ModelArgs) output_texts output_texts[0].replace(|begin_of_text|, ) print(output_texts)附论文里的那些工程细节上面的代码实现覆盖了 Llama 3 的核心架构但论文里还有几个工程细节值得单独讲一讲它们是真实生产训练里性能的关键保障也是理解为什么 Llama 3 能在这么大的规模上稳定训练的原因。1. 文档级别的 attention mask论文提到在标准预训练阶段效果有限但在长上下文继续预训练时至关重要。它的作用是当多个文档被拼接成一个长序列时阻止不同文档之间的 token 互相 attend。否则文档 A 的最后一个 token 会看到文档 B 的第一个 token这会注入本不应该存在的上下文关系污染长距离依赖的学习。2. 长上下文预训练Llama 3 的标准预训练上下文长度是 8K tokens但 Llama 3.1 的 405B 模型最终支持 128K tokens 的上下文。这不是一步到位的而是分六个阶段逐步扩展从 8K 到 128K用了约 8000 亿 token 的继续预训练来让模型适应。3. 退火Annealing在预训练的最后阶段把学习率线性退火到 0同时把高质量的数学和代码数据的权重调高。论文发现退火让 8B 模型在 GSM8k 上的准确率提升了 24%在 MATH 上提升了 6.4%。这提示我们数据质量在预训练的最后阶段比数量更重要。4. 缩放定律实验Meta 在确定 405B 这个参数规模之前跑了大量从 40M 到 16B 参数的小模型实验建立了 IsoFLOPs 曲线用幂律关系 N*(C) A × C^α 来预测最优参数规模。拟合出 α0.53A0.29据此推算出 3.8×10²⁵ FLOPs 对应的计算最优点约在 402B 参数最终选择了 405B。这是把科学方法引入工程决策的典型例子。论文链接https://arxiv.org/abs/2407.21783Meta 官方代码https://github.com/meta-llama/llama3

相关文章:

从零实现 Llama 3:架构拆解与实现细节

本文参考以下英文教程撰写:https://pub.towardsai.net/build-your-own-llama-3-architecture-from-scratch-using-pytorch-2ce1ecaa901c 第一次看到有人把 Llama 3 从零实现一遍,我就知道这件事值得认真做一次。因为只有真正写出来,才能体会…...

大麦网抢票自动化:从技术原理到实战落地的全方位指南

大麦网抢票自动化:从技术原理到实战落地的全方位指南 【免费下载链接】DamaiHelper 大麦网演唱会演出抢票脚本。 项目地址: https://gitcode.com/gh_mirrors/dama/DamaiHelper 问题引入:抢票困境与技术破局 在热门演出票务竞争日益激烈的当下&am…...

突破视频下载壁垒:yt-dlp-gui的全场景应用指南

突破视频下载壁垒:yt-dlp-gui的全场景应用指南 【免费下载链接】yt-dlp-gui Windows GUI for yt-dlp 项目地址: https://gitcode.com/gh_mirrors/yt/yt-dlp-gui 在数字化时代,视频内容已成为信息传递与知识获取的重要载体。然而,多数平…...

解锁浏览器超能力:Greasy Fork用户脚本平台完全指南

解锁浏览器超能力:Greasy Fork用户脚本平台完全指南 【免费下载链接】greasyfork An online repository of user scripts. 项目地址: https://gitcode.com/gh_mirrors/gr/greasyfork 认知启蒙:重新认识浏览器脚本的价值 还在为浏览器功能不足烦恼…...

亲测实用!6款覆盖全职业阶段的专业简历模板平台合集

很多人找工作的时候,都会卡在简历制作这一步。大家想要做出专业的简历,需要靠谱的专业简历模板平台,需要能直接参考的全行业简历案例,还需要能通过企业筛选的ATS适配简历模板。我整理了6款亲测好用的简历模板平台,国内…...

Stable Yogi Leather-Dress-Collection与智能车结合:生成个性化汽车内饰皮革方案

Stable Yogi Leather-Dress-Collection与智能车结合:生成个性化汽车内饰皮革方案 想象一下,你正坐在一辆智能车的展厅里,面前的巨大屏幕不是用来播放宣传片的,而是一个属于你的“数字裁缝铺”。你用手指轻轻滑动,选择…...

AI for Science新引擎:一文读懂符号计算的核心原理与实战指南

AI for Science新引擎:一文读懂符号计算的核心原理与实战指南 引言 在人工智能(AI)与科学研究(Science)深度融合的浪潮中,符号计算正从幕后走向台前,成为解决科学发现、工程优化等复杂问题的关…...

Phi-3-mini-128k-instruct处理复杂数据结构:算法题解答与优化展示

Phi-3-mini-128k-instruct处理复杂数据结构:算法题解答与优化展示 最近在尝试用一些轻量级的模型来辅助解决编程问题,特别是算法和数据结构这块。很多人觉得大模型只能写写简单的脚本,处理复杂逻辑可能不太行。正好手头有Phi-3-mini-128k-in…...

AI for Science新范式:当深度学习“求解”偏微分方程

AI for Science新范式:当深度学习“求解”偏微分方程 引言 在科学与工程的心脏地带,偏微分方程(PDE)如同描述万物规律的密码。从流体的舞蹈到宇宙的演化,传统数值方法(如有限元、有限体积法)虽…...

OpenClaw内存优化:千问3.5-35B-A3B-FP8在8GB设备的运行技巧

OpenClaw内存优化:千问3.5-35B-A3B-FP8在8GB设备的运行技巧 1. 为什么需要内存优化 当我第一次尝试在8GB内存的MacBook Pro上运行千问3.5-35B-A3B-FP8模型时,系统几乎立即崩溃了。这让我意识到,想要在资源有限的设备上运行大型语言模型&…...

实践之漏洞挖掘(弱口令)

前言:经过我的不懈努力,也是挖到了弱口令,嘻嘻,学校的,虽然没有泄露什么隐私,但是我交了要更新就是学校的漏洞,过不过都没关系,没过我下次就找有隐私的后台再交嘻嘻正题:…...

资源嗅探革新性工具:猫抓让网页资源获取变得前所未有的简单

资源嗅探革新性工具:猫抓让网页资源获取变得前所未有的简单 【免费下载链接】cat-catch 猫抓 浏览器资源嗅探扩展 / cat-catch Browser Resource Sniffing Extension 项目地址: https://gitcode.com/GitHub_Trending/ca/cat-catch 你是否曾经遇到过想要保存网…...

JiYuTrainer:重构教学控制逻辑的突破型技术方案

JiYuTrainer:重构教学控制逻辑的突破型技术方案 【免费下载链接】JiYuTrainer 极域电子教室防控制软件, StudenMain.exe 破解 项目地址: https://gitcode.com/gh_mirrors/ji/JiYuTrainer 构建多维度控制体系 💡 技术要点:通过内核级驱…...

降低OpenClaw Token消耗的三大实战策略,省钱后随便花,再也不用担心不够了

让AI“跑得更快、花得更少”:OpenClaw降本增效的终极实战手册 想象一下,你雇佣了一位才华横溢、但收费高昂的顶尖顾问。每次咨询,你都不厌其烦地把过去一整年的会议记录、所有项目文档、甚至茶水间的闲聊纪要都一股脑儿塞给他,然…...

Label Studio ML Backend架构设计与高可用机器学习服务实现深度解析

Label Studio ML Backend架构设计与高可用机器学习服务实现深度解析 【免费下载链接】label-studio-ml-backend Configs and boilerplates for Label Studios Machine Learning backend 项目地址: https://gitcode.com/gh_mirrors/la/label-studio-ml-backend Label Stu…...

告别学术阅读障碍:重新定义PDF翻译体验

告别学术阅读障碍:重新定义PDF翻译体验 【免费下载链接】PDFMathTranslate PDF scientific paper translation with preserved formats - 基于 AI 完整保留排版的 PDF 文档全文双语翻译,支持 Google/DeepL/Ollama/OpenAI 等服务,提供 CLI/GUI…...

从理论到模型:HFSS仿真平面发夹滤波器的关键步骤与参数优化

1. HFSS仿真前的理论准备 在开始HFSS仿真之前,我们需要先完成一些理论计算工作。这就像盖房子要先画图纸一样,没有理论指导的仿真就像无头苍蝇。我刚开始做滤波器设计时就犯过这个错误,直接上手建模,结果调参调到怀疑人生。 平面发…...

LongCat-Image-Editn V2效果展示:看AI如何精准将图中的猫变成狗

LongCat-Image-Editn V2效果展示:看AI如何精准将图中的猫变成狗 1. 效果惊艳开场:当AI成为你的修图助手 想象一下这样的场景:你拍了一张完美的照片,构图、光线、背景都无可挑剔,唯一的遗憾是照片里的主角——你的猫咪…...

PyTorch 2.8深度学习镜像实战教程:RTX 4090D + CUDA 12.4一键部署指南

PyTorch 2.8深度学习镜像实战教程:RTX 4090D CUDA 12.4一键部署指南 1. 镜像概述与环境准备 1.1 为什么选择这个镜像 如果你正在寻找一个开箱即用的深度学习环境,这个基于RTX 4090D 24GB显卡和CUDA 12.4优化的PyTorch 2.8镜像可能是理想选择。它专为…...

QPdf:Qt生态下的PDF渲染技术深度解析与现代应用实践

QPdf:Qt生态下的PDF渲染技术深度解析与现代应用实践 【免费下载链接】qpdf PDF viewer widget for Qt 项目地址: https://gitcode.com/gh_mirrors/qpd/qpdf 在Qt应用开发中,PDF文档处理一直是个技术痛点。传统方案要么依赖平台原生组件导致跨平台…...

开启iphone的墙纸玻璃效果

要开启 iPhone 的墙纸“玻璃效果”,需注意:苹果并未在 iOS 中提供名为“玻璃效果”的独立开关,但通过 “液态玻璃”(Liquid Glass)设计风格 和 “空间场景”壁纸 等功能,可实现类似视觉效果。以下是基于最新公开资料的操作指南&am…...

5分钟快速上手:AI视频生成工具完整指南

5分钟快速上手:AI视频生成工具完整指南 【免费下载链接】auto-video-generateor 自动视频生成器,给定主题,自动生成解说视频。用户输入主题文字,系统调用大语言模型生成故事或解说的文字,然后进一步调用语音合成接口生…...

可能是综合性能最强的PCIe 5.0 SSD!铠侠EXCERIA PRO G2 2TB评测:AIDA64线性写入全程不掉速

一、前言:铠侠首款旗舰级PCIe 5.0 SSD 可能很多读者会疑惑,作为存储领域的一线巨头,在PCIe 5.0时代,为什么铠侠迟迟没有推出旗舰级SSD产品! 这主要是因为,早期的PCIe 5.0 SSD主控功耗极高(超过10W)&#xf…...

8大核心功能解决网盘下载难题:Online-disk-direct-link-download-assistant完全指南

8大核心功能解决网盘下载难题:Online-disk-direct-link-download-assistant完全指南 【免费下载链接】Online-disk-direct-link-download-assistant 一个基于 JavaScript 的网盘文件下载地址获取工具。基于【网盘直链下载助手】修改 ,支持 百度网盘 / 阿…...

别再硬用Search API了!Qdrant纯Payload查询的正确姿势:Scroll API实战与性能调优

别再硬用Search API了!Qdrant纯Payload查询的正确姿势:Scroll API实战与性能调优 最近在重构一个电商后台系统时,我发现团队里不少工程师都在用Qdrant的Search API做纯Payload字段查询——比如按订单状态筛选数据、根据商品标签过滤结果集。这…...

我们这些程序员在人工智能时代注定要失败吗?(一位穷困潦倒的计算机科学系学生)

Reddit上有个帖子让我看了心里一紧。 标题很简单,却像一把刀:"Are we devs doomed in AI world? A broke CS student."(我们在AI世界注定要失败吗?一位穷困潦倒的计算机科学系学生) 发帖人没留下名字,就写了一句话:学编程是为了改变命运,结果发现命运被AI改…...

B站硬核会员试炼的AI自动答题工具:从痛点到实践的完整指南

B站硬核会员试炼的AI自动答题工具:从痛点到实践的完整指南 【免费下载链接】bili-hardcore bilibili 硬核会员 AI 自动答题脚本,直接调用 B 站 API,非 OCR 实现 项目地址: https://gitcode.com/gh_mirrors/bi/bili-hardcore 一、痛点剖…...

Pyodide 0.26:WebAssembly Python的突破性升级

Pyodide 0.26:WebAssembly Python的突破性升级 【免费下载链接】pyodide Pyodide is a Python distribution for the browser and Node.js based on WebAssembly 项目地址: https://gitcode.com/gh_mirrors/py/pyodide 在WebAssembly技术快速发展的今天&…...

从“技术迷宫“到“一键导航“:OpCore-Simplify如何让黑苹果配置变得像搭积木一样简单

从"技术迷宫"到"一键导航":OpCore-Simplify如何让黑苹果配置变得像搭积木一样简单 【免费下载链接】OpCore-Simplify A tool designed to simplify the creation of OpenCore EFI 项目地址: https://gitcode.com/GitHub_Trending/op/OpCore-S…...

ArknightsGameResource:模块化游戏资源库与标准化数据解析技术指南

ArknightsGameResource:模块化游戏资源库与标准化数据解析技术指南 【免费下载链接】ArknightsGameResource 明日方舟客户端素材 项目地址: https://gitcode.com/gh_mirrors/ar/ArknightsGameResource ArknightsGameResource项目为《明日方舟》游戏开发者提供…...