当前位置: 首页 > article >正文

机器翻译模型笔记

机器翻译学习笔记(简体中文)

1. 任务概述

  • 目标:将英文句子翻译成简体中文。

  • 示例

    • 输入:Tom is a student.

    • 输出:汤姆是一个学生。

  • 框架:Seq2Seq(序列到序列)模型。

2. 数据预处理

2.1 下载数据

  • 数据集:TED2020(英文-简体中文对齐的平行语料)。

  • 代码

    # 下载TED2020数据集的压缩文件
    # - wget命令用于从指定URL下载文件
    # - -O选项指定下载文件的保存路径和名称
    # - 目的:获取训练所需的双语平行语料
    !wget https://github.com/yuhsinchan/ML2022-HW5Dataset/releases/download/v1.0.2/ted2020.tgz -O ./DATA/rawdata/ted2020/ted2020.tgz# 解压下载的压缩文件
    # - tar命令用于解压.tgz文件
    # - -xvf选项表示解压、显示详细信息、并指定文件
    # - -C选项指定解压的目标目录
    # - 目的:将压缩文件解压到指定的数据目录中
    !tar -xvf ./DATA/rawdata/ted2020/ted2020.tgz -C ./DATA/rawdata/ted2020/
  • 通俗总结:这段代码就像从网上下载一个装满英文和中文翻译句子的压缩包,然后把它解压到一个文件夹里。这些句子是模型训练的“教材”,里面有成千上万的英文句子和对应的中文翻译,方便模型学习怎么把英文变成中文。

2.2 清洗和规范化

  • 功能:移除噪声、统一格式。

  • 代码

    def clean_s(s, lang):"""清洗和规范化文本数据,根据语言类型执行不同的预处理操作。参数:s (str): 输入的原始文本字符串。lang (str): 语言类型,'en' 表示英文,'zh' 表示中文。返回:str: 经过清洗和规范化后的文本字符串。"""if lang == 'en':  # 如果语言是英文,进入英文文本处理逻辑# 移除括号及其内部内容# - 使用正则表达式匹配括号及其内容:\( 表示左括号,\) 表示右括号# - [^()]* 表示括号内不包含括号的任意字符,* 表示匹配零次或多次# - re.sub 将匹配的括号及其内容替换为空字符串# - 目的:移除补充说明或注释性内容,减少翻译时的噪声s = re.sub(r"\([^()]*\)", "", s)# 移除连字符 '-'# - 使用 replace 方法将所有的连字符替换为空字符串# - 目的:简化英文文本,减少不必要的符号,提升文本一致性s = s.replace('-', '')# 在标点符号前后添加空格# - ([.,;!?()\"]) 是一个正则表达式,匹配常见的英文标点符号# - r' \1 ' 表示在匹配的标点前后各添加一个空格,\1 代表匹配的标点本身# - re.sub 执行替换操作# - 目的:确保标点与单词分离,便于分词工具识别,提升模型处理能力s = re.sub('([.,;!?()\"])', r' \1 ', s)elif lang == 'zh':  # 如果语言是中文,进入中文文本处理逻辑# 将全角字符转换为半角字符# - strQ2B 是一个外部函数,专门用于将全角字符(如全角标点)转为半角字符# - 目的:统一字符编码,减少模型处理不同编码字符时的复杂度s = strQ2B(s)# 移除括号及其内部内容# - 与英文处理相同,使用正则表达式匹配并移除括号及其内容# - \( 和 \) 分别表示左右括号,[^()]* 表示括号内的任意非括号字符# - 目的:去除补充说明或非主要内容,保持翻译数据的干净s = re.sub(r"\([^()]*\)", "", s)# 移除空格和特定字符,并统一引号样式# - replace(' ', '') 移除所有空格,因为中文文本通常不依赖空格分词# - replace('—', '') 移除中文破折号,清理不必要的符号# - replace('“', '"') 和 replace('”', '"') 将中文引号替换为英文引号# - 目的:规范化中文文本,去除冗余符号,保持格式一致性s = s.replace(' ', '').replace('—', '').replace('“', '"').replace('”', '"')# 在中文标点符号前后添加空格# - ([。,;!?()\"~「」]) 匹配常见的中文标点符号# - r' \1 ' 在匹配的标点前后添加空格,\1 表示匹配的标点本身# - re.sub 执行替换操作# - 目的:将标点与文字分离,便于分词和模型处理标点s = re.sub('([。,;!?()\"~「」])', r' \1 ', s)# 规范化文本中的空格# - s.strip() 移除字符串首尾的多余空格# - split() 将字符串按空格分割成列表,自动移除连续空格# - ' '.join() 将列表元素用单个空格连接成字符串# - 目的:确保单词或字符之间只有一个空格,统一文本格式return ' '.join(s.strip().split())
  • 通俗总结:这段代码就像个“文本清洁工”,专门整理英文和中文句子。英文句子会把括号里的备注删掉、去掉连字符、给标点两边加空格;中文句子会把全角标点改成半角、删掉空格和破折号、统一引号样式,也给标点加空格。最后把多余的空格都清理掉,让句子看起来整洁统一,方便模型理解。

2.3 移除不良数据

  • 功能:根据长度和比例移除不合适的句子对。

  • 代码

    def clean_corpus(prefix, l1, l2, ratio=9, max_len=1000, min_len=1):# 打开源语言和目标语言的文件进行读取# - f'{prefix}.{l1}' 是源语言文件路径,l1 是源语言代码(如 'en')# - f'{prefix}.{l2}' 是目标语言文件路径,l2 是目标语言代码(如 'zh')with open(f'{prefix}.{l1}', 'r') as l1_in_f, open(f'{prefix}.{l2}', 'r') as l2_in_f:# 打开清洗后的源语言和目标语言文件进行写入# - f'{prefix}.clean.{l1}' 是清洗后的源语言文件路径# - f'{prefix}.clean.{l2}' 是清洗后的目标语言文件路径with open(f'{prefix}.clean.{l1}', 'w') as l1_out_f, open(f'{prefix}.clean.{l2}', 'w') as l2_out_f:# 逐行读取源语言文件for s1 in l1_in_f:s1 = s1.strip()  # 去除源语言句子首尾的空白字符s2 = l2_in_f.readline().strip()  # 读取目标语言的对应行并去除首尾空白s1 = clean_s(s1, l1)  # 调用clean_s函数清洗源语言文本s2 = clean_s(s2, l2)  # 调用clean_s函数清洗目标语言文本s1_len = len_s(s1, l1)  # 计算源语言文本的长度(len_s是一个外部函数)s2_len = len_s(s2, l2)  # 计算目标语言文本的长度# 跳过过短的句子对# - 如果任一语言的长度小于min_len(默认1),则跳过if s1_len < min_len or s2_len < min_len:continue# 跳过过长的句子对# - 如果任一语言的长度大于max_len(默认1000),则跳过if s1_len > max_len or s2_len > max_len:continue# 跳过长度比例不合理的句子对# - 如果源语言和目标语言的长度比例超过ratio(默认9),则跳过# - 防止翻译对长度差异过大,影响模型训练if s1_len / s2_len > ratio or s2_len / s1_len > ratio:continue# 将清洗后的句子写入对应文件# - print默认会在末尾添加换行符,file参数指定输出文件print(s1, file=l1_out_f)print(s2, file=l2_out_f)
  • 通俗总结:这个代码就像个“句子筛选员”,把英文和中文的句子对一对一检查。先用 clean_s 清理句子,然后看看句子长度:太短(少于1个词)没啥用,太长(超过1000个词)模型吃不消,扔掉;如果英文和中文句子长度差太多(比如一个9倍长),可能翻译不靠谱,也扔掉。合格的句子对就保存到新文件里,留给模型用。

2.4 分词(Subword Units)

  • 工具:SentencePiece。

  • 代码

    import sentencepiece as spm
    # 训练SentencePiece模型
    # - input: 指定训练数据文件,多个文件用逗号分隔
    # - model_prefix: 模型文件的前缀,用于保存训练好的模型和词汇表
    # - vocab_size: 设置词汇表大小(8000),控制subword的数量
    # - model_type: 模型类型,'unigram' 是一种基于unigram语言模型的subword分割方法
    # - 目的:生成subword模型,将文本分割成子词单元,减少词汇量并处理未登录词
    spm.SentencePieceTrainer.train(input=','.join([f'{prefix}/train.clean.{src_lang}', f'{prefix}/valid.clean.{src_lang}']),model_prefix=f'{prefix}/spm8000',model_type='unigram'
    )
  • 通俗总结:这段代码用 SentencePiece 工具给句子“切词”。它不按完整单词切,而是把句子拆成小片段(像“playing”可能拆成“play”和“ing”),生成一个8000个“片段”的词典。这样模型不用记太多单词,也能处理没见过的词,训练起来更省力。

2.5 二值化(Binarize)

  • 工具:fairseq。

  • 代码

    # 使用fairseq的preprocess命令对数据进行二值化
    # - --source-lang: 指定源语言(如 'en')
    # - --target-lang: 指定目标语言(如 'zh')
    # - --trainpref, --validpref, --testpref: 指定训练、验证、测试数据集的前缀路径
    # - --destdir: 指定二值化数据的保存目录
    # - --joined-dictionary: 使用联合词典,即源语言和目标语言共享同一个词典
    # - --workers: 指定并行处理的worker数量,提升处理速度
    # - 目的:将文本数据转换为fairseq可直接使用的二进制格式,加速数据加载和训练
    !python -m fairseq_cli.preprocess \--source-lang {src_lang} --target-lang {tgt_lang} \--trainpref {prefix}/train --validpref {prefix}/valid --testpref {prefix}/test \--destdir {binpath} --joined-dictionary --workers 2
  • 通俗总结:这个代码把清理好的句子从文本变成“机器专用格式”(二进制文件),就像把书本内容压缩成电脑能快速读的代码。fairseq 工具会把训练、验证、测试数据都处理好,英文和中文用同一个词典,加快模型读取数据的速度。

3. 模型定义

3.1 RNN Seq2Seq 模型

3.1.1 编码器(RNNEncoder)
  • 功能:将输入句子转为嵌入向量,使用双向 GRU 处理。

  • 代码

    class RNNEncoder(FairseqEncoder):def __init__(self, args, dictionary, embed_tokens):# 初始化编码器,继承自FairseqEncodersuper().__init__(dictionary)self.embed_tokens = embed_tokens  # 词嵌入层,将token ID映射为嵌入向量self.embed_dim = args.encoder_embed_dim  # 嵌入向量的维度,从参数中获取self.hidden_dim = args.encoder_ffn_embed_dim  # GRU隐藏层的维度,从参数中获取self.num_layers = args.encoder_layers  # GRU的层数,从参数中获取self.dropout_in_module = nn.Dropout(args.dropout)  # 在输入嵌入后应用的dropout层self.rnn = nn.GRU(self.embed_dim,  # 输入维度,即词嵌入的维度self.hidden_dim,  # 隐藏层维度,控制GRU的容量self.num_layers,  # GRU层数,决定深度dropout=args.dropout,  # dropout概率,防止过拟合batch_first=False,  # 输入形状为(seq_len, batch, embed_dim)bidirectional=True  # 双向GRU,捕捉前后上下文信息)self.dropout_out_module = nn.Dropout(args.dropout)  # 在GRU输出后应用的dropout层self.padding_idx = dictionary.pad()  # 从词典中获取padding token的IDdef forward(self, src_tokens, **unused):# 前向传播函数,处理源语言输入bsz, seqlen = src_tokens.size()  # 获取batch size和序列长度# 将token ID转换为嵌入向量x = self.embed_tokens(src_tokens)  # 输出形状:(batch, seq_len, embed_dim)x = self.dropout_in_module(x)  # 在嵌入向量上应用dropout,减少过拟合x = x.transpose(0, 1)  # 调整为(seq_len, batch, embed_dim),适配GRU输入# 初始化隐藏状态# - new_zeros生成全零张量,形状为(2 * num_layers, batch, hidden_dim)# - 2 * num_layers 因为是双向GRU,每层有正向和反向h0 = x.new_zeros(2 * self.num_layers, bsz, self.hidden_dim)# 通过双向GRU处理输入# - 输出x: (seq_len, batch, hidden_dim*2),包含每个时间步的隐藏状态# - final_hiddens: (num_layers*2, batch, hidden_dim),每层的最终隐藏状态x, final_hiddens = self.rnn(x, h0)outputs = self.dropout_out_module(x)  # 在GRU输出上应用dropout# 处理双向GRU的隐藏状态# - combine_bidir是自定义方法,将正向和反向隐藏状态合并final_hiddens = self.combine_bidir(final_hiddens, bsz)# 创建padding mask,标记padding位置# - eq比较token是否等于padding_idx,t()转置为(seq_len, batch)encoder_padding_mask = src_tokens.eq(self.padding_idx).t()# 返回GRU输出、最终隐藏状态和padding maskreturn outputs, final_hiddens, encoder_padding_mask
  • 通俗总结:这段代码建了一个“句子理解器”(编码器)。它把英文句子里的每个词变成数字向量,再用双向 GRU(一种记忆力强的神经网络)从头到尾、从尾到头读一遍句子,记住每个词的上下文信息。GRU 像个聪明的笔记员,能记住重要的东西,忘了不重要的,最后输出句子的“精华信息”和一个标记,告诉模型哪些是填充的无效部分。

3.1.2 解码器(RNNDecoder)
  • 功能:结合注意力机制和单向 GRU 生成翻译。

  • 代码

    class RNNDecoder(FairseqIncrementalDecoder):def __init__(self, args, dictionary, embed_tokens):# 初始化解码器,继承自FairseqIncrementalDecodersuper().__init__(dictionary)self.embed_tokens = embed_tokens  # 词嵌入层,将token ID映射为嵌入向量self.embed_dim = args.decoder_embed_dim  # 嵌入向量的维度,从参数中获取self.hidden_dim = args.decoder_ffn_embed_dim  # GRU隐藏层的维度,从参数中获取self.num_layers = args.decoder_layers  # GRU的层数,从参数中获取self.dropout_in_module = nn.Dropout(args.dropout)  # 输入dropout层self.rnn = nn.GRU(self.embed_dim,  # 输入维度,即词嵌入的维度self.hidden_dim,  # 隐藏层维度,控制GRU容量self.num_layers,  # GRU层数,决定深度dropout=args.dropout,  # dropout概率,防止过拟合batch_first=False,  # 输入形状(seq_len, batch, embed_dim)bidirectional=False  # 单向GRU,按序生成输出)self.attention = AttentionLayer(...)  # 注意力机制层,动态关注编码器输出self.dropout_out_module = nn.Dropout(args.dropout)  # 输出dropout层# 如果隐藏层维度与嵌入维度不一致,添加投影层if self.hidden_dim != self.embed_dim:self.project_out_dim = nn.Linear(self.hidden_dim, self.embed_dim)else:self.project_out_dim = None# 输出投影层,将隐藏状态映射到词汇表大小self.output_projection = nn.Linear(self.embed_dim, len(dictionary))def forward(self, prev_output_tokens, encoder_out, incremental_state=None, **unused):# 前向传播函数,生成目标语言输出# 从编码器获取输出、隐藏状态和padding maskencoder_outputs, encoder_hiddens, encoder_padding_mask = encoder_outif incremental_state is not None and len(incremental_state) > 0:# 如果有增量状态(用于推理),只取上一个时间步的输出prev_output_tokens = prev_output_tokens[:, -1:]cache_state = self.get_incremental_state(incremental_state, "cached_state")prev_hiddens = cache_state["prev_hiddens"]  # 获取缓存的隐藏状态else:# 训练时或推理的第一个时间步,使用编码器的隐藏状态prev_hiddens = encoder_hiddensbsz, seqlen = prev_output_tokens.size()  # 获取batch size和序列长度# 将token ID转换为嵌入向量x = self.embed_tokens(prev_output_tokens)  # 输出形状:(batch, seq_len, embed_dim)x = self.dropout_in_module(x)  # 在嵌入向量上应用dropoutx = x.transpose(0, 1)  # 调整为(seq_len, batch, embed_dim),适配GRU# 应用注意力机制,动态关注编码器输出if self.attention is not None:x, attn = self.attention(x, encoder_outputs, encoder_padding_mask)# 通过单向GRU处理输入# - 输出x: (seq_len, batch, hidden_dim),每个时间步的隐藏状态# - final_hiddens: (num_layers, batch, hidden_dim),最终隐藏状态x, final_hiddens = self.rnn(x, prev_hiddens)x = self.dropout_out_module(x)  # 在GRU输出上应用dropout# 如果需要,投影到嵌入维度if self.project_out_dim is not None:x = self.project_out_dim(x)# 投影到词汇表大小,生成logitsx = self.output_projection(x)  # 输出形状:(seq_len, batch, vocab_size)x = x.transpose(1, 0)  # 调整为(batch, seq_len, vocab_size),适配后续处理# 更新增量状态(用于推理)cache_state = {"prev_hiddens": final_hiddens}self.set_incremental_state(incremental_state, "cached_state", cache_state)return x, None  # 返回logits和额外信息(此处为None)
  • 通俗总结:这个“句子生成器”(解码器)负责把英文句子的“精华信息”翻译成中文。它先把已生成的中文词变成数字向量,然后用单向 GRU(一个能记住之前内容的神经网络)一步步生成新词。注意力机制像个“参考指南”,让解码器随时回头看英文句子的关键部分,确保翻译准确。最后,它输出每个词的概率,决定下一个中文词是什么。

3.1.3 Seq2Seq 模型
  • 功能:整合编码器和解码器。

  • 代码

    class Seq2Seq(FairseqEncoderDecoderModel):def __init__(self, args, encoder, decoder):# 初始化Seq2Seq模型,继承自FairseqEncoderDecoderModelsuper().__init__(encoder, decoder)self.args = args  # 保存参数对象,包含模型配置self.encoder = encoder  # 编码器实例self.decoder = decoder  # 解码器实例def forward(self, src_tokens, src_lengths, prev_output_tokens, return_all_hiddens: bool = True):# 前向传播函数,处理完整翻译过程# 编码器处理源语言输入# - src_tokens: 源语言token序列# - src_lengths: 每个序列的实际长度encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens)# 解码器处理目标语言输入和编码器输出# - prev_output_tokens: 前一时间步的目标语言token序列# - encoder_out: 编码器的输出logits, extra = self.decoder(prev_output_tokens,encoder_out=encoder_out,src_lengths=src_lengths,return_all_hiddens=return_all_hiddens)return logits, extra  # 返回logits(预测概率)和额外信息
  • 通俗总结:这个代码把编码器和解码器“组装”成一个完整的翻译机器。编码器先读懂英文句子,提取重要信息;解码器再根据这些信息生成中文句子。整个过程就像一个人先听懂一句外语,然后用自己的语言复述出来。

3.2 Transformer 模型

3.2.1 修改模型定义
  • 代码

    from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
    # 使用Transformer编码器
    # - args: 模型参数
    # - src_dict: 源语言词典
    # - encoder_embed_tokens: 源语言词嵌入层
    encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
    # 使用Transformer解码器
    # - args: 模型参数
    # - tgt_dict: 目标语言词典
    # - decoder_embed_tokens: 目标语言词嵌入层
    decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
    # 构建Seq2Seq模型,整合编码器和解码器
    model = Seq2Seq(args, encoder, decoder)
  • 通俗总结:这段代码把模型从 RNN 升级成 Transformer。Transformer 像个更聪明的翻译员,不用按顺序读句子,而是同时看整个句子,通过“自注意力”机制快速抓住重点。编码器和解码器还是干老本行,但用 Transformer 的方式更快更准。

3.2.2 调整超参数
  • 代码

    arch_args = Namespace(encoder_embed_dim=512,  # 编码器嵌入维度,控制输入表示的容量encoder_ffn_embed_dim=2048,  # 编码器前馈网络维度,增加模型复杂度encoder_layers=4,  # 编码器层数,决定深度decoder_embed_dim=512,  # 解码器嵌入维度,与编码器对齐decoder_ffn_embed_dim=2048,  # 解码器前馈网络维度,增加容量decoder_layers=4,  # 解码器层数,决定深度encoder_attention_heads=8,  # 编码器注意力头数,提升并行处理能力decoder_attention_heads=8,  # 解码器注意力头数,提升并行处理能力dropout=0.3  # dropout概率,防止过拟合
    )
  • 通俗总结:这段代码是给 Transformer 模型“调参数”,就像调汽车引擎。把词向量的维度设为512,网络层数设为4层,注意力头数设为8个,增加模型的“脑容量”。还加了30%的 dropout,防止模型“死记硬背”,让它更灵活。

4. 训练和优化

4.1 损失函数(Label Smoothing)

  • 功能:加入平滑项防止过拟合。

  • 代码

    class LabelSmoothedCrossEntropyCriterion(nn.Module):def __init__(self, smoothing, ignore_index=None, reduce=True):# 初始化Label Smoothed Cross Entropy损失函数super().__init__()self.smoothing = smoothing  # 平滑参数,控制平滑程度self.ignore_index = ignore_index  # 忽略的标签ID(通常是padding)self.reduce = reduce  # 是否对损失进行reduce(求和或平均)def forward(self, lprobs, target):# 前向传播,计算损失# 如果target维度比lprobs少1维,调整target维度if target.dim() == lprobs.dim() - 1:target = target.unsqueeze(-1)  # 增加一维,与lprobs对齐# 计算NLL损失(负对数似然)# - gather从lprobs中提取target对应的概率nll_loss = -lprobs.gather(dim=-1, index=target)# 计算平滑损失# - 对所有类别的概率求和,作为平滑项smooth_loss = -lprobs.sum(dim=-1, keepdim=True)# 如果有忽略的标签,mask掉对应位置if self.ignore_index is not None:pad_mask = target.eq(self.ignore_index)  # 创建padding masknll_loss.masked_fill_(pad_mask, 0.0)  # 将padding位置的损失置0smooth_loss.masked_fill_(pad_mask, 0.0)  # 将padding位置的平滑损失置0else:nll_loss = nll_loss.squeeze(-1)  # 移除多余维度smooth_loss = smooth_loss.squeeze(-1)  # 移除多余维度if self.reduce:nll_loss = nll_loss.sum()  # 对NLL损失求和smooth_loss = smooth_loss.sum()  # 对平滑损失求和# 计算最终损失# - eps_i: 平滑项的权重,smoothing除以词汇表大小# - 组合NLL损失和平滑损失eps_i = self.smoothing / lprobs.size(-1)loss = (1.0 - self.smoothing) * nll_loss + eps_i * smooth_lossreturn loss
  • 通俗总结:这段代码定义了模型的“评分标准”(损失函数)。它检查模型预测的中文词和正确答案的差距,算出“错误分数”。还加了“标签平滑”,让模型别太自信,给其他可能的词留点余地,防止它死记硬背答案,提高翻译的灵活性。

4.2 优化器(NoamOpt)

  • 功能:动态调整学习率。

  • 代码

    def get_rate(d_model, step_num, warmup_step):# Noam学习率调度公式# - d_model: 模型维度,用于缩放学习率# - step_num: 当前步数# - warmup_step: 预热步数,控制学习率初始增长# - 学习率先线性增加(step_num * warmup_step**(-1.5)),然后按步数的逆平方根衰减(step_num**(-0.5))lr = d_model ** (-0.5) * min(step_num ** (-0.5), step_num * warmup_step ** (-1.5))return lrclass NoamOpt:def __init__(self, model_size, factor, warmup, optimizer):# 初始化Noam优化器self.optimizer = optimizer  # 底层优化器(如Adam)self._step = 0  # 当前步数,初始化为0self.warmup = warmup  # 预热步数,控制学习率增长阶段self.factor = factor  # 学习率缩放因子,调整整体学习率大小self.model_size = model_size  # 模型维度,用于学习率计算self._rate = 0  # 当前学习率,初始化为0def step(self):# 更新学习率并执行优化器步骤self._step += 1  # 步数加1rate = self.rate()  # 计算当前学习率# 更新底层优化器的学习率for p in self.optimizer.param_groups:p['lr'] = rateself._rate = rate  # 保存当前学习率self.optimizer.step()  # 执行优化器更新参数def rate(self, step=None):# 计算当前学习率if step is None:step = self._step  # 如果未提供步数,使用当前步数return self.factor * get_rate(self.model_size, step, self.warmup)  # 返回缩放后的学习率
  • 通俗总结:这段代码控制模型学习的“节奏”。它用一个叫 Noam 的方法动态调整学习率:刚开始学得慢点(预热),像热身运动;后面学得快点,但随着训练深入又慢慢放缓,防止学过头。就像教小孩,先慢慢教,熟练后再加速,最后小心调整。

4.3 训练循环

  • 功能:计算损失并更新参数。

  • 代码

    def train_one_epoch(epoch_itr, model, task, criterion, optimizer, accum_steps=1):# 训练一个epoch# - epoch_itr: 迭代器,提供每个epoch的数据# - model: 待训练的模型# - task: 任务对象,定义输入输出格式# - criterion: 损失函数# - optimizer: 优化器# - accum_steps: 梯度累积步数itr = epoch_itr.next_epoch_itr(shuffle=True)  # 获取数据迭代器,打乱数据顺序itr = iterators.GroupedIterator(itr, accum_steps)  # 分组迭代器,实现梯度累积stats = {"loss": []}  # 记录损失统计信息scaler = GradScaler()  # 自动混合精度缩放器,提升训练效率model.train()  # 设置模型为训练模式# 使用tqdm显示训练进度progress = tqdm.tqdm(itr, desc=f"train epoch {epoch_itr.epoch}", leave=False)for samples in progress:model.zero_grad()  # 清零模型梯度accum_loss = 0  # 计算累积损失初始化sample_size = 0  # 样本token数量初始化# 遍历累积步数内的样本for i, sample in enumerate(samples):if i == 1:torch.cuda.empty_cache()  # 清理cuda缓存,释放显存sample = utils.move_to_cuda(sample, device=device)  # 将数据移动到GPUtarget = sample["target"]  # 获取目标token序列sample_size_i = sample["ntokens"]  # 获取当前样本的token数量sample_size += sample_size_i  # 累加token数量with autocast():  # 使用自动混合精度,节省显存并加速计算net_output = model.forward(**sample["net_input"])  # 前向传播,计算模型输出lprobs = F.log_softmax(net_output[0], -1)  # 计算log概率# 计算损失,view(-1)将张量展平loss = criterion(lprobs.view(-1, lprobs.size(-1)), target.view(-1))accum_loss += loss.item()  # 累积损失值(标量)scaler.scale(loss).backward()  # 缩放损失并反向传播scaler.step(optimizer)  # 更新模型参数scaler.update()  # 更新优化器状态scaler.unscale_(optimizer)  # 反缩放梯度,准备更新参数# 梯度optimizer.multiply_grads(1 / (sample_size orárv0.0))  # 梯度归一化,除以总token数,标准化更新幅度gnorm = nn.utils.clip_grad_norm_(model.parameters(), config.clip_norm)  # 梯度裁剪,防止梯度爆炸scaler.step(optimizer)  # 更新模型参数scaler.update()  # 更新缩放器状态loss_print = accum_loss / sample_size  # 计算平均损失stats["loss'].append(loss_print)  # 记录平均损失progress.set_postfix(loss=loss_print)  # 更新进度条显示# 如果使用wandb,记录训练日志if config.use_wandb:wandb.log({"train/loss": loss_print,  # 平均损失"train/grad_norm": gnorm.item(),  # 梯度范数"train/lr": optimizer.rate(),  # 当前学习率"train/sample_size": sample_size,  # 样本token数})loss_print = np.mean(stats["loss"])  # 计算整个epoch的平均损失logger.info(f"training loss: {loss_print:.4f}")  # 记录损失日志return stats  # 返回训练统计信息
  • 通俗总结:这段代码是模型的“学习过程”。它把数据分成小份喂给模型,模型预测后和正确答案对比,算出“错误分数”(损失)。然后调整模型的“知识点”,让错误更少。用了混合精度和GPU加速,像用高科技教机器;还加了梯度累积和裁剪,避免学得太猛,确保稳扎稳打。

5. 反向翻译(Back-translation)

5.1 训练反向语言

  • 修改配置

    # 修改配置以训练反向语言模型(中文到英文)
    config.source_lang = "zh"  # 设置源语言为中文
    config.target_lang = "en"  config.target_lang = "en"  # 设置目标语言为英文
    config.savedir = "./checkpoints/transformer-back"  # 设置模型检查点保存路径
  • 通俗总结:这段代码调整了训练方向,告诉模型这次要学“中文翻英文”,而不是“英文翻中文”。就像让一个翻译员练习反向翻译,专门为后面生成假数据做准备。

5.2 生成合成数据

下载单语数据
  • 代码

    # 下载中文单语数据
    # - 获取指定URL下载压缩文件
    # - -O 指定保存路径和文件名
    !wget https://github.com/yuhsinchan/ML2022-HW5Dataset/releases/download/v1.0.2/ted_zh_corpus.deduped.gz -O {output_prefix}/ted_zh_corpus.deduped.gz
    # 解压数据
    # - gzip -fkd 提取文件,保留源文件
    !gzip -d {output_prefix}/ted_zh_corpus.deduped.gz
  • 通俗总结:这段代码去网上抓了一堆只有中文的句子(没英文翻译),然后解压出来。这些中文句子是“额外教材”,用来造更多训练数据。

清洗单语数据
  • 代码

    def clean_mono_corpus(input_path, output_path, lang='zh'):# 清洗单语数据# - input_path: 输入文件路径# - output_path: 输出文件路径# - lang: 语言类型,默认为中文with open(input_path, 'r') as in_f:with open(output_path, 'w') as out_f:for line in in_f:line = clean_s(line.strip(), lang)  # 使用clean_s清洗文本# 控制句子长度在1到1000之间if 1 <= len(line) <= 1000:print(line, file=out_f)  # 写入清洗后的句子
    clean_mono_corpus(f'{output_prefix}/ted_zh_corpus.deduped', f'{output_prefix}/mono.clean.zh')
  • 通俗总结:这段代码把刚下载的中文句子再洗一遍,用 clean_s 清理杂乱部分(像去掉括号、规范标点)。还筛掉太长或太短的句子,留下的干净句子存到新文件里,准备下一步用。

分词
  • 代码

    # 加载SentencePiece模型
    spm_model = spm.SentencePieceProcessor(model_file=str(prefix/f'spm8000.model'))
    with open(f'{mono_prefix}/mono.tok.zh', 'w') as out_f:with open(f'{mono_prefix}/mono.clean.zh', 'r') as in_f:for line in in_f:# encode 将文本分割成subword,out_type=str 返回字符串列表tok = spm_model.encode(line.strip(), out_type=str)# 用空格连接subword并写入文件print(' '.join(tok), file=out_f)
  • 通俗总结:这段代码把中文句子切成小片段(subword),就像把“学习”切成“学”和“习”。用之前训练的 SentencePiece 模型把每个句子拆开,再用空格连起来存到文件,方便模型处理。

二值化
  • 代码

    # 使用fairseq对单语数据进行二值化
    # - 配置语言和数据路径
    # - --trainpref: 指定训练数据前缀
    # - --destdir: 指定保存目录
    # - --srcdict, --tgtdict: 指定字典文件
    !python -m fairseq_cli.preprocess \--source-lang 'zh' --target-lang 'en' \--trainpref {mono_prefix}/mono.tok \--destdir ./DATA/data-bin/mono \--srcdict ./DATA/data-bin/ted2020/dict.en.txt \--tgtdict ./DATA/data-bin/ted2020/dict.en.txt \--workers 2
  • 通俗总结:这段代码把切好的中文句子转成机器能快速读的二进制格式,就像把中文“笔记”压缩成电脑文件。用 fairseq 工具,指定中文到英文的词典,确保数据格式和之前的模型一致。

生成预测
  • 代码

    # 使用训练好的模型生成翻译预测
    # - model: 反向翻译模型
    # - task: 任务对象
    # - split: 数据集分割(这里为mono)
    # - outfile: 输出文件路径
    generate_prediction(model, task, split="mono", outfile="./prediction.txt")
  • 通俗总结:这段代码用“中文翻英文”的模型,把中文句子翻译成英文,存到文件里。就像请一个翻译员把中文“教材”翻译成英文,生成一堆假的英文句子,后面用来扩充数据。

5.3 创建新数据集

合并数据
  • 代码

    # 将预测结果(英文)进行分词并保存
    with open(f'{mono_prefix}/mono.tok.en', 'w') as out_f:with open('./prediction.txt', 'r') as in_f:for line in in_f:# 对预测的英文文本进行subword分词tok = spm_model.encode(line.strip(), out_type=str)# 用空格连接subword并写入文件print(' '.join(tok), file=out_f)
  • 通俗总结:这段代码把翻译出来的英文句子也切成小片段(subword),和中文处理一样。然后把这些片段存到文件,形成一组“中文-假英文”的句子对,准备加入训练数据。

二值化新数据
  • 代码

    # 定义二值化数据保存路径
    binpath = Path('./DATA/data-bin/synthetic')
    # 对合成数据进行二值化
    # - 配置参数
    # - --trainpref: 指定训练数据前缀
    # - --destdir: 保存路径
    # - --srcdict, --tgtdict: 指定字典文件
    !python -m fairseq_cli.preprocess \--source-lang 'zh' --target-lang 'en' \--trainpref {mono_prefix}/mono.tok \--destdir {binpath} \--srcdict ./DATA/data-bin/ted2020/dict.en.txt \--tgtdict ./data-bin/ted2020/dict.en.txt.gz \--workers 2
  • 通俗总结:这段代码把“中文-假英文”句子对转成二进制格式,存到新文件夹。就像把新造的翻译数据压缩好,确保格式和原来的数据一样,方便模型一起用。

整合到原始数据
  • 代码

    # 复制原始数据到新目录
    !cp -r ./DATA/data-bin/ted2020/ ./DATA/data-bin/ted2020_with_mono/
    # 将合成数据的英文二进制文件复制并重命名
    !cp ./DATA/data-bin/synthetic/train.zh-en.en.bin ./DATA/data-bin/ted2020_with_mono/train1.en-zh.en.bin
    # 类似地复制其他文件
  • 通俗总结:这段代码把假数据和原始数据“拼”在一起,像把新教材和旧教材装进一个大书包,确保模型能同时学到真翻译和假翻译的数据。

修改配置并训练
  • 代码

    # 配置使用包含合成数据的数据集
    config.datadir = "./DATA/data-bin/ted2020_with_mono"  # 数据目录
    config.source_lang = "en"  # 源语言为英文
    config.target_lang = "zh"  # 目标语言为中文
    config.savedir = "./checkpoints/transformer-bt"  # 检查点保存路径
  • 通俗总结:这段代码告诉模型“现在用新扩充的数据集(包含真假数据)重新学一遍英文到中文的翻译”,并指定保存路径。就像告诉学生用新课本复习,目标是翻译得更厉害。

6. 总结

  • RNN Seq2Seq:结构简单,依赖 GRU 逐步处理序列,适合基础翻译任务。

  • Transformer:用自注意力机制代替 GRU,能同时看整个句子,翻译更快更准。

  • 反向翻译:通过“造假数据”(用中文生成英文对),扩充训练材料,让模型学得更全面。

相关文章:

机器翻译模型笔记

机器翻译学习笔记&#xff08;简体中文&#xff09; 1. 任务概述 目标&#xff1a;将英文句子翻译成简体中文。 示例&#xff1a; 输入&#xff1a;Tom is a student. 输出&#xff1a;汤姆是一个学生。 框架&#xff1a;Seq2Seq&#xff08;序列到序列&#xff09;模型。…...

Ref vs. Reactive:Vue 3 响应式变量的最佳选择指南

Ref vs. Reactive&#xff1a;Vue 3 响应式变量的最佳选择指南 在 Vue 3 的 Composition API 中&#xff0c;ref 和 reactive 是创建响应式数据的两种主要方式。许多开发者经常困惑于何时使用哪种方式。本文将深入对比两者的差异&#xff0c;帮助您做出最佳选择。 核心概念解…...

让视觉基础模型(VFMs)像大语言模型(LLMs)一样“会思考”​

视觉检测器的演进&#xff1a;从 DETR 到 Grounding-DINO DINO-R1 的基础是 Grounding-DINO&#xff0c;而 Grounding-DINO 本身是一系列视觉检测器演进的结果。理解这个发展过程对掌握 DINO-R1 的核心技术至关重要。 DETR&#xff1a;用 Transformer 革新目标检测 在 DETR&…...

现代前端框架的发展与演进

现代前端框架的发展与演进是一个非常值得关注的话题&#xff0c;反映了整个前端生态系统的不断演化与技术深度的提升。以下是这一趋势的详细解析&#xff1a; &#x1f4c8; 现代前端框架的发展与演进 &#x1f539; 第一阶段&#xff1a;jQuery 时代&#xff08;2006-2013&am…...

【LLM-Agent】智能体的记忆缓存设计

note 实践&#xff1a;https://modelscope-agent.readthedocs.io/zh-cn/latest/modules/memory.html 文章目录 note一、Agent的记忆实现二、相关综述三、记忆体的构建四、cursor的记忆设计1. 记忆生成提示词2. 记忆评估提示词 五、记忆相关的MCPReference 一、Agent的记忆实现…...

一起学Spring AI:核心概念

人工智能概念 本节描述了 Spring AI 使用的核心概念。我们建议您仔细阅读&#xff0c;以理解 Spring AI 实现背后的思想。 模型&#xff08;Models&#xff09; 人工智能模型是设计用来处理和生成信息的算法&#xff0c;通常模仿人类的认知功能。通过从大型数据集中学习模式…...

Oracle业务用户的存储过程个数及行数统计

Oracle业务用户的存储过程个数及行数统计 统计所有业务用户存储过程的个数独立定义的存储过程定义在包里的存储过程统计所有业务用户存储过程的总行数独立定义的存储过程定义在包里的存储过程📖 对存储过程进行统计主要用到以下三个系统视图: dba_objects:记录了所有独立创…...

PicSharp(图片压缩工具) v1.1.6

PicSharp 一个简单、高效、灵活的跨平台桌面图像压缩应用程序。软件基于Rust实现&#xff0c;高性能低资源&#xff0c;能快速扫描文件或目录&#xff0c;批处理图像。软件还具备组合压缩策略&#xff0c;TinyPNG提供最佳压缩比&#xff0c;但需要互联网连接&#xff0c;对大量…...

前端文件下载常用方式详解

在前端开发中&#xff0c;实现文件下载是常见的需求。根据不同的场景&#xff0c;我们可以选择不同的方法来实现文件流的下载。本文介绍三种常用的文件下载方式&#xff1a; 使用 axios 发送 JSON 请求下载文件流使用 axios 发送 FormData 请求下载文件流使用原生 form 表单提…...

【DAY42】Grad-CAM与Hook函数

内容来自浙大疏锦行python打卡训练营 浙大疏锦行 知识点: 回调函数lambda函数hook函数的模块钩子和张量钩子Grad-CAM的示例 作业&#xff1a;理解下今天的代码即可 在深度学习中&#xff0c;我们经常需要查看或修改模型中间层的输出或梯度。然而&#xff0c;标准的前向传播和反…...

如何生成和制作PDF文件

在数字化办公的今天&#xff0c;PDF文件已经成为我们工作和学习中不可或缺的一部分。无论是合同、报告、简历&#xff0c;还是电子书、表单&#xff0c;PDF格式都以其跨平台兼容性、不可编辑性和清晰的排版而被广泛使用。但你是否知道&#xff0c;生成和制作PDF文件其实并不复杂…...

【K8S系列】Kubernetes 中 Pod(Java服务)启动缓慢的深度分析与解决方案

本文针对 Kubernetes 中 Java 服务启动时间慢的深度分析与解决方案文章,结合了底层原理、常见原因及具体优化策略: Kubernetes 中 Java 服务启动缓慢的深度分析与高效解决方案 在 Kubernetes 上部署 Java 应用时,启动时间过长是常见痛点,尤其在需要快速扩缩容或滚动更新的…...

【Java学习笔记】StringBuilder类(重点)

StringBuilder&#xff08;重点&#xff09; 1. 基本介绍 是一个可变的字符串序列。该类提供一个与 StringBuffer 兼容的 API&#xff0c;但不保证同步&#xff08;StringBuilder 不是线程安全的&#xff09; 该类被设计用作 StringBuffer 的一个简易替换&#xff0c;用在字符…...

JavaScript ES6 解构:优雅提取数据的艺术

JavaScript ES6 解构&#xff1a;优雅提取数据的艺术 在 JavaScript 的世界中&#xff0c;ES6&#xff08;ECMAScript 2015&#xff09;的推出为开发者带来了许多革命性的特性&#xff0c;其中“解构赋值”&#xff08;Destructuring Assignment&#xff09;无疑是最受欢迎的功…...

iview Switch Tabs TabPane 使用提示Maximum call stack size exceeded堆栈溢出

在vue项目中使用iview 框架部分组件时&#xff0c;直接引入使用报Maximum call stack size exceeded image.png 堆栈溢出 解决方案 更换组件名称就可以了 image.png 或 image.png 就可以了 猜测是因为和vue自己提供的组件名称一致了&#xff0c;重名问题导致的&#xff0c;具体…...

基于Halcon深度学习之分类

***** ***环境准备*** ***系统&#xff1a;win7以上系统 ***显卡&#xff1a;算力3.0以上 ***显卡驱动&#xff1a;10.1以上版本&#xff08;nvidia-smi查看指令&#xff09;***读取深度学习模型*** read_dl_model (pretrained_dl_classifier_compact.hdl, DLModelHandle) ***获…...

零基础在实践中学习网络安全-皮卡丘靶场(第十五期-URL重定向模块)

本期内容和之前的CSRF&#xff0c;File inclusion有联系&#xff0c;复习后可以更好了解 介绍 不安全的url跳转 不安全的url跳转问题可能发生在一切执行了url地址跳转的地方。如果后端采用了前端传进来的(可能是用户传参,或者之前预埋在前端页面的url地址)参数作为了跳转的目…...

技巧小结:根据寄存器手册写常用外设的驱动程序

需求&#xff1a;根据STM32F103寄存器手册写DMA模块的驱动程序 一、分析标准库函数的写法&#xff1a; 各个外设的寄存器地址定义在stm32f10x.h文件中&#xff1a;此文件由芯片厂家提供;内核的有关定义则定义在core_cm3.h文件中&#xff1a;ARM提供; 1、查看外设区域多级划分…...

设计模式(代理设计模式)

代理模式解释清楚&#xff0c;所以如果想对一个类进行功能上增强而又不改变原来的代码情况下&#xff0c;那么只需要让这个类代理类就是我们的顺丰&#xff0c;对吧?并行增强就可以了。具体增强什么?在哪方面增强由代理类进行决定。 代码实现就是使用代理对象代理相关的逻辑…...

从代码学习深度强化学习 - 初探强化学习 PyTorch版

文章目录 前言强化学习的概念强化学习的环境强化学习中的数据强化学习的独特性总结前言 本文将带你初步了解强化学习 (Reinforcement Learning, RL) 的基本概念,并通过 PyTorch 实现一些简单的强化学习算法。强化学习是一种让智能体 (agent) 通过与环境 (environment) 的交互…...

AI大神吴恩达-提示词课程笔记

如何有效编写提示词 在学习如何与语言模型&#xff08;如ChatGPT&#xff09;交互时&#xff0c;编写清晰且高效的提示词&#xff08;Prompt&#xff09;是至关重要的。本课程由ESA提供&#xff0c;重点介绍了提示词工程&#xff08;Prompt Engineering&#xff09;的两个核心…...

ArcGIS Pro 3.4 二次开发 - 地图探索

环境:ArcGIS Pro SDK 3.4 + .NET 8 文章目录 地图探索1 地图视图1.1 测试视图是否为3D1.2 设置视图模式1.3 启用视图链接2 更新地图视图范围2.1 返回上一个相机视图2.2 切换到下一个相机视角2.3 缩放到全图范围2.4 固定放大2.5 固定缩小2.6 缩放到范围2.7 缩放到一个点2.8 缩放…...

ELK日志管理框架介绍

在小铃铛的毕业设计中涉及到了ELK日志管理框架&#xff0c;在调研期间发现在中文中没有很好的对ELK框架进行介绍的文章&#xff0c;因此拟在本文中进行较为详细的实现的介绍。 理论知识 ELK 框架介绍 ELK 是一个流行的开源日志管理解决方案堆栈&#xff0c;由三个核心组件组…...

【Linux】sed 命令详解及使用样例:流式文本编辑器

【Linux】sed 命令详解及使用样例&#xff1a;流式文本编辑器 引言 sed 是 Linux/Unix 系统中一个强大的流式文本编辑器&#xff0c;名称来源于 “Stream EDitor”&#xff08;流编辑器&#xff09;。它允许用户在不打开文件的情况下对文本进行筛选和转换&#xff0c;是命令行…...

机器学习:聚类算法及实战案例

本文目录&#xff1a; 一、聚类算法介绍二、分类&#xff08;一&#xff09;根据聚类颗粒度分类&#xff08;二&#xff09;根据实现方法分类 三、聚类流程四、K值的确定—肘部法&#xff08;一&#xff09;SSE-误差平方和&#xff08;二&#xff09;肘部法确定 K 值 五、代码重…...

预览pdf(url格式和blob格式)

<template><div class"pdf-container"><div v-if"loading" class"loading-state"><a-spin size"large" /></div><div v-else-if"error" class"loading-state">加载失败&…...

【p2p、分布式,区块链笔记 MESH】 论文阅读 Thread/OpenThread Low-Power Wireless Multihop Net

paperauthorThread/OpenThread: A Compromise in Low-Power Wireless Multihop Network Architecture for the Internet of ThingsHyung-Sin Kim, Sam Kumar, and David E. Culler 目录 引言RPL 标准设计目标与架构设计选择与特性shortcomIngs of RPL设计选择的反面影响sImulta…...

for AC500 PLCs 3ADR025003M9903的安全说明

1安全说明 必须遵守特殊的环境条件(例如&#xff0c;由于爆炸性物质、重污染或腐蚀影响的危险区域)。必须在指定的技术数据和系统数据范围内处理和操作设备。该装置不含可维修部件&#xff0c;不得打开。除非另有规定&#xff0c;否则操作过程中必须关闭可拆卸的盖子。拒绝对不…...

moon游戏服务器-demo运行

下载地址 https://github.com/sniper00/MoonDemo redis安装 Redis-x64-3.0.504.msi 服务器配置文件 D:\gitee\moon_server_demo\serverconf.lua 貌似不修改也可以的&#xff0c;redis不要设置密码 windows编译 安装VS2022 Community 下载premake5.exe放MoonDemo\server\moon 双…...

前端(vue)学习笔记(CLASS 7):vuex

vuex概述 vuex是一个vue的状态管理工具&#xff0c;状态就是数据 大白话&#xff1a;vuex是一个插件&#xff0c;可以帮我们管理vue通用的数据&#xff08;多组件共享的数据&#xff09; 场景 1、某个状态在很多个组件来使用&#xff08;个人信息&#xff09; 2、多个组件…...